In [1]:
from os import read
import torch
from lietorch import SO3, SE3, LieGroupParameter

import argparse
import numpy as np
import time
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import chamfer3D.dist_chamfer_3D
# import pytorch3d
from pdb import post_mortem, set_trace as bp

import itertools
from torch.autograd import Variable
import random
import torch.nn.functional as FF
from sklearn.cluster import kmeans_plusplus
import open3d as o3d
import os

Jitting Chamfer 3D
Loaded JIT 3D CUDA chamfer distance
Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def read_obj(obj_path,for_open_mesh=False):
    with open(obj_path) as file:
        flag = 0
        points = []
        normals = []
        faces = []
        while 1:
            line = file.readline()
            if not line:
                break
            strs = line.split(" ")
            if strs[0] =='o' and flag==0:
                flag = 1
                continue
            elif strs[0]=='o' and flag==1:
                break
            if strs[0] =='v':
                points.append((float(strs[1]),float(strs[2]),float(strs[3])))
            
            if strs[0] == 'vn':
                normals.append((float(strs[1]),float(strs[2]),float(strs[3])))
            if strs[0] =='f':
                single_line_face = strs[1:]
                
#                 for sf in single_line_face:
#                     face_index = sf.split('/')
#                     faces.append((int(face_index[0]),int(face_index[1]),int(face_index[2])))
#                     break
                f_co = []
                for sf in single_line_face:
                    face_tmp = sf.split('/')[0]
                    f_co.append(face_tmp)
                if for_open_mesh == False:
                    if len(f_co)==3:
                        faces.append((int(f_co[0]),int(f_co[1]),int(f_co[2])))
                    elif len(f_co)==4:
                        faces.append((int(f_co[0]),int(f_co[1]),int(f_co[2])))
                        faces.append((int(f_co[0]),int(f_co[1]),int(f_co[3])))
                        faces.append((int(f_co[1]),int(f_co[2]),int(f_co[3])))
                        faces.append((int(f_co[0]),int(f_co[3]),int(f_co[2])))
                else:
                    faces.append([int(ver) for ver in f_co])
                
    points = np.array(points)

    normals = np.array(normals)
    faces = np.array(faces)
    return points,normals,faces



In [3]:
class LaplacianLoss(nn.Module):
    def __init__(self, vertex, faces, average=False):
        super(LaplacianLoss, self).__init__()
        self.nv = vertex.shape[0]
        self.nf = faces.shape[0]
        self.average = average
        laplacian = np.zeros([self.nv, self.nv]).astype(np.float32)

        laplacian[faces[:, 0], faces[:, 1]] = -1
        laplacian[faces[:, 1], faces[:, 0]] = -1
        laplacian[faces[:, 1], faces[:, 2]] = -1
        laplacian[faces[:, 2], faces[:, 1]] = -1
        laplacian[faces[:, 2], faces[:, 0]] = -1
        laplacian[faces[:, 0], faces[:, 2]] = -1

        r, c = np.diag_indices(laplacian.shape[0])
        laplacian[r, c] = -laplacian.sum(1)

        
        for i in range(self.nv):
            if laplacian[i, i]!=0: 
                laplacian[i, :] /= laplacian[i, i]

        self.register_buffer('laplacian', torch.from_numpy(laplacian))

    def forward(self, x):
        batch_size = x.shape[0]
        x = torch.matmul(self.laplacian.cuda(), x)
        dims = tuple(range(x.ndimension())[1:])
        x = x.pow(2).sum(dims)
        if self.average:
            return x.sum() / batch_size
        else:
            return x

In [4]:
if __name__ == '__main__':

    seed = 2000
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

#     points_info,normals_info,face_info = read_obj('frame_000001.obj')
#     old_ones = np.ones((points_info.shape[0],1))
    
    
#     points_info = np.concatenate((points_info,old_ones),axis=1)
    all_base_list = []
    all_info_list = []
    
    for i in range(20,30):
        path = "frame_{:06d}.obj".format(i+1)

        tmp_base = read_obj(path)[0]
#         print(tmp_base)
        base_ones = np.ones((tmp_base.shape[0],1))
        tmp_base_new = np.concatenate((tmp_base,base_ones),axis=1)
        all_base_list.append(tmp_base_new)
        
    for i in range(21, 31):
        path = "frame_{:06d}.obj".format(i+1)

        tmp_target = read_obj(path)[0]
        target_ones = np.ones((tmp_target.shape[0],1))
        tmp_target_new = np.concatenate((tmp_target,target_ones),axis=1)
        all_info_list.append(tmp_target_new)


        
    points_info = tmp_target_new.copy()
    old_info = points_info
    
#     points_info = tmp_target_
#     la_loss = LaplacianLoss(torch.tensor(points_info),torch.tensor(face_info-1),average=True)
    
    all_base = np.stack(all_base_list)
    all_info = np.stack(all_info_list)
    


    base_info = all_base
    target_info = all_info
    
    # W,R= train(old_info=points_info,target_info=all_info)

    N = points_info.shape[0]
    F = target_info.shape[0]
    B = 40  ### 这个肯定可以通过读pkl得到

#     center, indexes = kmeans_plusplus(old_info[:, :-1], n_clusters=B, random_state=0)
    
    old_mesh = torch.tensor(old_info).float().cuda().detach()

#     center_t = torch.tensor(center).repeat(F, 1, 1).cuda()
    
    
    base_mesh = torch.tensor(base_info,requires_grad=False).float().cuda().detach()
    target_mesh = torch.tensor(target_info, requires_grad=False).float().cuda().detach()


    W = torch.randn((N, B), requires_grad=True, device="cuda")
    W2 = torch.randn((N, B), requires_grad=True, device="cuda")

    p1 = torch.randn((F, B, 7), requires_grad=True, device="cuda")

    optimizer = optim.Adam([W,  p1],lr=1e-2)

    chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
    for i in range(10000):
        # if i == 2000:
        #     optimizer = optim.Adam([W, p1],lr=1e-3)
        optimizer.zero_grad()
        loss = 0
        loss2 = 0
        f_idxs = random.sample(range(F), F)
        
        for f in f_idxs:
            T = torch.randn(B, 4, 4)
            W1 = FF.softmax(W * 10)
            W1 = (W1 / (W1.sum(1, keepdim=True).detach()))

            
            for b in range(B):

                R = SE3.InitFromVec(p1[f, b])
                # bp()
                T[b] = R.matrix()
            T = T.cuda()     

#             x = old_mesh
            x = base_mesh[f]
        
        
            bx = (T@x.T).permute(2, 0, 1)
            wbx = W1.unsqueeze(2) * bx
            # bp()
            wbx =  wbx.permute((1,0,2))
            # print("Wbx", wbx.shape)
            # wbx = wbx.sum(0, keepdim=True)[:, :, :-1]
            wbx = wbx.sum(0, keepdim=True)

            loss += torch.sum((target_mesh[f]-wbx)**2)

#         (loss + loss2).backward()
        loss.backward()
        # print("this is loss for epoch {} : {} ".format(i,loss.detach().cpu().numpy()))
        print("this is loss for epoch {} : {} ".format(i,loss / F))
        # print("p1: ", p1)
        optimizer.step()


    new_R = SE3.InitFromVec(p1)
    R = new_R.matrix()





this is loss for epoch 0 : 11188.337890625 
this is loss for epoch 1 : 10718.556640625 
this is loss for epoch 2 : 10265.974609375 
this is loss for epoch 3 : 9831.287109375 
this is loss for epoch 4 : 9414.4248046875 
this is loss for epoch 5 : 9015.1630859375 
this is loss for epoch 6 : 8633.2724609375 
this is loss for epoch 7 : 8268.279296875 


KeyboardInterrupt: 

In [11]:
color_list = []
number_list = [0,0.2,0.4,0.6,0.8,1]
for i in range(B):
    color = random.sample(number_list,3)
    color_list.append(color)
color_array = np.array(color_list)

In [12]:
W_new = W1.clone()
with torch.no_grad():
    for i in range(W1.shape[0]):
        max_idx = torch.topk(W1[i], 3)[1]
        for j in range(B):
            if j not in max_idx:
                W1[i, j] = 0
        W_new[i] = W1[i] / W1[i].sum()

In [13]:
W_c = W_new.detach().cpu().numpy()

In [14]:
(W_c@color_array).shape

(3199, 3)

In [15]:
import open3d as o3d
pcd = o3d.geometry.PointCloud()
# pcd2 = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(wbx_result[-1][:, :-1])
# pcd.points = o3d.utility.Vector3dVector(points_info)
pcd.colors = o3d.utility.Vector3dVector(W_c@color_array)
# pcd.normals = o3d.utility.Vector3dVector(normals_info)
# pcd2.paint_uniform_color([0, 0, 1])
o3d.visualization.draw_geometries([pcd])