In [5]:
from os import read
import torch
from lietorch import SO3, SE3, LieGroupParameter

import argparse
import numpy as np
import time
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import chamfer3D.dist_chamfer_3D
# import pytorch3d
from pdb import post_mortem, set_trace as bp

import itertools
from torch.autograd import Variable
import random
import torch.nn.functional as FF
from sklearn.cluster import kmeans_plusplus
import open3d as o3d
import os

In [6]:
def read_obj(obj_path,for_open_mesh=False):
    with open(obj_path) as file:
        flag = 0
        points = []
        normals = []
        faces = []
        while 1:
            line = file.readline()
            if not line:
                break
            strs = line.split(" ")
            if strs[0] =='o' and flag==0:
                flag = 1
                continue
            elif strs[0]=='o' and flag==1:
                break
            if strs[0] =='v':
                points.append((float(strs[1]),float(strs[2]),float(strs[3])))
            
            if strs[0] == 'vn':
                normals.append((float(strs[1]),float(strs[2]),float(strs[3])))
            if strs[0] =='f':
                single_line_face = strs[1:]
                
#                 for sf in single_line_face:
#                     face_index = sf.split('/')
#                     faces.append((int(face_index[0]),int(face_index[1]),int(face_index[2])))
#                     break
                f_co = []
                for sf in single_line_face:
                    face_tmp = sf.split('/')[0]
                    f_co.append(face_tmp)
                if for_open_mesh == False:
                    if len(f_co)==3:
                        faces.append((int(f_co[0]),int(f_co[1]),int(f_co[2])))
                    elif len(f_co)==4:
                        faces.append((int(f_co[0]),int(f_co[1]),int(f_co[2])))
                        faces.append((int(f_co[0]),int(f_co[1]),int(f_co[3])))
                        faces.append((int(f_co[1]),int(f_co[2]),int(f_co[3])))
                        faces.append((int(f_co[0]),int(f_co[3]),int(f_co[2])))
                else:
                    faces.append([int(ver) for ver in f_co])
                
    points = np.array(points)

    normals = np.array(normals)
    faces = np.array(faces)
    return points,normals,faces



In [7]:
class LaplacianLoss(nn.Module):
    def __init__(self, vertex, faces, average=False):
        super(LaplacianLoss, self).__init__()
        self.nv = vertex.shape[0]
        self.nf = faces.shape[0]
        self.average = average
        laplacian = np.zeros([self.nv, self.nv]).astype(np.float32)

        laplacian[faces[:, 0], faces[:, 1]] = -1
        laplacian[faces[:, 1], faces[:, 0]] = -1
        laplacian[faces[:, 1], faces[:, 2]] = -1
        laplacian[faces[:, 2], faces[:, 1]] = -1
        laplacian[faces[:, 2], faces[:, 0]] = -1
        laplacian[faces[:, 0], faces[:, 2]] = -1

        r, c = np.diag_indices(laplacian.shape[0])
        laplacian[r, c] = -laplacian.sum(1)

        
        for i in range(self.nv):
            if laplacian[i, i]!=0: 
                laplacian[i, :] /= laplacian[i, i]

        self.register_buffer('laplacian', torch.from_numpy(laplacian))

    def forward(self, x):
        batch_size = x.shape[0]
        x = torch.matmul(self.laplacian.cuda(), x)
        dims = tuple(range(x.ndimension())[1:])
        x = x.pow(2).sum(dims)
        if self.average:
            return x.sum() / batch_size
        else:
            return x

In [4]:
if __name__ == '__main__':
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--old', default='frame_000001.obj')
    # parser.add_argument('--target', default='frame_000002.obj')
    # args = parser.parse_args()

    # output_path = args.problem.replace('.g2o', '_rotavg.g2o')
    # input_pose_graph = read_g2o(args.problem)

    seed = 2000
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    points_info,normals_info,face_info = read_obj('frame_000001.obj')
    old_ones = np.ones((points_info.shape[0],1))
    
    
    points_info = np.concatenate((points_info,old_ones),axis=1)
    all_info_list = []
    for i in range(20, 30):
        path = "frame_{:06d}.obj".format(i+1)

        tmp_target = read_obj(path)[0]
        target_ones = np.ones((tmp_target.shape[0],1))
        tmp_target_new = np.concatenate((tmp_target,target_ones),axis=1)
        all_info_list.append(tmp_target_new)

        
    la_loss = LaplacianLoss(torch.tensor(points_info),torch.tensor(face_info-1),average=True)
    all_info = np.stack(all_info_list)
    old_info = points_info
    target_info = all_info
    # W,R= train(old_info=points_info,target_info=all_info)

    N = points_info.shape[0]
    F = target_info.shape[0]
    B = 40  ### 这个肯定可以通过读pkl得到

    center, indexes = kmeans_plusplus(old_info[:, :-1], n_clusters=B, random_state=0)
    old_mesh = torch.tensor(old_info).float().cuda().detach()
    # bp()
    # ones = torch.ones(old_mesh.shape[0], 1).cuda()
    # x = torch.cat([old_mesh, ones], dim=1)

    center_t = torch.tensor(center).repeat(F, 1, 1).cuda()
    target_mesh = torch.tensor(target_info, requires_grad=False).float().cuda().detach()

    # bp()
    # W = torch.randint(10,(N,B))
    # W = torch.randn((N,B)).cuda()
    # W = np.random.randint(10,size=(N,B))
    # bp()
    # W = SO3(torch.from_numpy(W).float().cuda())
    # W = torch.zeros((N, B), dtype=torch.float, requires_grad=True, device="cuda")
    W = torch.randn((N, B), requires_grad=True, device="cuda")
    W2 = torch.randn((N, B), requires_grad=True, device="cuda")
    # W = torch.tensor(W0.float(), requires_grad=True, device="cuda")
    # W = W * 0
    # bp()
    # W = LieGroupParameter(W)
    # bp()
    # R = torch.randint(10,(B,3,3))

    # bp()pcd.points = Vector3dVector(np_points)
    # random quaternion
   # bounding_box_max = old_mesh.max()
    # bounding_box_min = old_mesh.min()
    p1 = torch.randn((F, B, 7), requires_grad=True, device="cuda")
    # with torch.no_grad():
    #     p1[:, :, 4:] = center_t
    # W = W.cuda()
    # p1 = p1.cuda()
    # p1 = p1 / p1.norm(dim=-1, keepdim=True)
    # create SO3 object from quaternion (differentiable w.r.t q)
    
    # 4x4 transformation matrix (differentiable w.r.t R)
    

    # W = torch.tensor(np.load('weight.npy')).float().cuda()
    optimizer = optim.Adam([W,  p1],lr=1e-2)

    chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
    for i in range(10000):
        # if i == 2000:
        #     optimizer = optim.Adam([W, p1],lr=1e-3)
        optimizer.zero_grad()
        loss = 0
        loss2 = 0
        f_idxs = random.sample(range(F), F)
        
        for f in f_idxs:
            T = torch.randn(B, 4, 4)
            W1 = FF.softmax(W * 10)
            W1 = (W1 / (W1.sum(1, keepdim=True).detach()))
            # W3 = FF.softmax(W2)
            # W3 = (W3 / (W3.sum(1, keepdim=True).detach()))
            # dist= W3.unsqueeze(2) * old_mesh.unsqueeze(1).repeat(1, 45, 1)
            # dist = dist.sum(0)
            
            for b in range(B):
                # R = SO3.InitFromVec(p1[f, b])
                # bp()
                # bp()
                # with torch.no_grad():
                # temp = torch.cat([p1[f, b], dist[f][:-1]])
                R = SE3.InitFromVec(p1[f, b])
                # bp()
                T[b] = R.matrix()
            T = T.cuda()     
            # bp()      
            # bp() 
            # W1 = W
            # W1 = (W - W.min()) / (W.max() - W.min())
            
            # bp()
            # p1[:, :, 4:] = 
            # bp()
            # ones = torch.ones(old_mesh.shape[0], 1).cuda()
            # x = torch.cat([old_mesh, ones], dim=1)
            x = old_mesh
        
        
            bx = (T@x.T).permute(2, 0, 1)
            wbx = W1.unsqueeze(2) * bx
            # bp()
            wbx =  wbx.permute((1,0,2))
            # print("Wbx", wbx.shape)
            # wbx = wbx.sum(0, keepdim=True)[:, :, :-1]
            wbx = wbx.sum(0, keepdim=True)
            # loss_init = chamLoss(target_mesh[f][None],wbx)
            # loss += loss_init[0].sum()

#             bp()
            loss += torch.sum((target_mesh[f]-wbx)**2)
            # loss2 += la_loss(W1) * 100
            # loss += torch.sum((target_mesh[f]-wbx)**2) +la_loss(W.detach().cpu())
#             loss.backward()
            # print("loss",loss)
            # bp()
        # bp()
        (loss + loss2).backward()
        # print("this is loss for epoch {} : {} ".format(i,loss.detach().cpu().numpy()))
        print("this is loss for epoch {} : {}, {} ".format(i,loss / F, loss2/F))
        # print("p1: ", p1)
        optimizer.step()


    new_R = SE3.InitFromVec(p1)
    R = new_R.matrix()

    # W1 = (W - W.min()) / (W.max() - W.min())
    # W1 = (W1 / (W1.sum(1, keepdim=True).detach()))
    # bp()
    # return R.detach(), W1.detach()



this is loss for epoch 0 : 10621.0302734375, 0.0 
this is loss for epoch 1 : 10201.5625, 0.0 
this is loss for epoch 2 : 9797.103515625, 0.0 
this is loss for epoch 3 : 9407.7861328125, 0.0 
this is loss for epoch 4 : 9033.6826171875, 0.0 
this is loss for epoch 5 : 8674.6005859375, 0.0 
this is loss for epoch 6 : 8330.240234375, 0.0 
this is loss for epoch 7 : 8000.20654296875, 0.0 
this is loss for epoch 8 : 7684.0009765625, 0.0 
this is loss for epoch 9 : 7381.02294921875, 0.0 
this is loss for epoch 10 : 7090.6494140625, 0.0 
this is loss for epoch 11 : 6812.24169921875, 0.0 
this is loss for epoch 12 : 6545.09912109375, 0.0 
this is loss for epoch 13 : 6288.43115234375, 0.0 
this is loss for epoch 14 : 6041.42041015625, 0.0 
this is loss for epoch 15 : 5803.32373046875, 0.0 
this is loss for epoch 16 : 5573.54541015625, 0.0 
this is loss for epoch 17 : 5351.662109375, 0.0 
this is loss for epoch 18 : 5137.42724609375, 0.0 
this is loss for epoch 19 : 4930.75048828125, 0.0 
this is

this is loss for epoch 158 : 36.410213470458984, 0.0 
this is loss for epoch 159 : 35.95897674560547, 0.0 
this is loss for epoch 160 : 35.51527786254883, 0.0 
this is loss for epoch 161 : 35.07224655151367, 0.0 
this is loss for epoch 162 : 34.636226654052734, 0.0 
this is loss for epoch 163 : 34.21841049194336, 0.0 
this is loss for epoch 164 : 33.81858825683594, 0.0 
this is loss for epoch 165 : 33.42978286743164, 0.0 
this is loss for epoch 166 : 33.052490234375, 0.0 
this is loss for epoch 167 : 32.68193435668945, 0.0 
this is loss for epoch 168 : 32.31393051147461, 0.0 
this is loss for epoch 169 : 31.951309204101562, 0.0 
this is loss for epoch 170 : 31.60064697265625, 0.0 
this is loss for epoch 171 : 31.267438888549805, 0.0 
this is loss for epoch 172 : 30.961530685424805, 0.0 
this is loss for epoch 173 : 30.686193466186523, 0.0 
this is loss for epoch 174 : 30.429296493530273, 0.0 
this is loss for epoch 175 : 30.174530029296875, 0.0 
this is loss for epoch 176 : 29.91326904

this is loss for epoch 311 : 14.613351821899414, 0.0 
this is loss for epoch 312 : 14.560078620910645, 0.0 
this is loss for epoch 313 : 14.506464004516602, 0.0 
this is loss for epoch 314 : 14.45256519317627, 0.0 
this is loss for epoch 315 : 14.398126602172852, 0.0 
this is loss for epoch 316 : 14.343317031860352, 0.0 
this is loss for epoch 317 : 14.288042068481445, 0.0 
this is loss for epoch 318 : 14.23188304901123, 0.0 
this is loss for epoch 319 : 14.175396919250488, 0.0 
this is loss for epoch 320 : 14.11927318572998, 0.0 
this is loss for epoch 321 : 14.063422203063965, 0.0 
this is loss for epoch 322 : 14.008699417114258, 0.0 
this is loss for epoch 323 : 13.954690933227539, 0.0 
this is loss for epoch 324 : 13.902055740356445, 0.0 
this is loss for epoch 325 : 13.853045463562012, 0.0 
this is loss for epoch 326 : 13.801360130310059, 0.0 
this is loss for epoch 327 : 13.747629165649414, 0.0 
this is loss for epoch 328 : 13.69444751739502, 0.0 
this is loss for epoch 329 : 13.

this is loss for epoch 464 : 9.098176956176758, 0.0 
this is loss for epoch 465 : 9.078062057495117, 0.0 
this is loss for epoch 466 : 9.057596206665039, 0.0 
this is loss for epoch 467 : 9.037481307983398, 0.0 
this is loss for epoch 468 : 9.017661094665527, 0.0 
this is loss for epoch 469 : 8.998431205749512, 0.0 
this is loss for epoch 470 : 8.979658126831055, 0.0 
this is loss for epoch 471 : 8.961170196533203, 0.0 
this is loss for epoch 472 : 8.942598342895508, 0.0 
this is loss for epoch 473 : 8.92423152923584, 0.0 
this is loss for epoch 474 : 8.906206130981445, 0.0 
this is loss for epoch 475 : 8.88769245147705, 0.0 
this is loss for epoch 476 : 8.869561195373535, 0.0 
this is loss for epoch 477 : 8.85083293914795, 0.0 
this is loss for epoch 478 : 8.831855773925781, 0.0 
this is loss for epoch 479 : 8.812976837158203, 0.0 
this is loss for epoch 480 : 8.79391098022461, 0.0 
this is loss for epoch 481 : 8.775126457214355, 0.0 
this is loss for epoch 482 : 8.756684303283691, 0.

this is loss for epoch 619 : 6.817063808441162, 0.0 
this is loss for epoch 620 : 6.800783634185791, 0.0 
this is loss for epoch 621 : 6.786895275115967, 0.0 
this is loss for epoch 622 : 6.774596691131592, 0.0 
this is loss for epoch 623 : 6.7633256912231445, 0.0 
this is loss for epoch 624 : 6.752620220184326, 0.0 
this is loss for epoch 625 : 6.74228048324585, 0.0 
this is loss for epoch 626 : 6.732032299041748, 0.0 
this is loss for epoch 627 : 6.721748352050781, 0.0 
this is loss for epoch 628 : 6.7114129066467285, 0.0 
this is loss for epoch 629 : 6.701120853424072, 0.0 
this is loss for epoch 630 : 6.690834045410156, 0.0 
this is loss for epoch 631 : 6.680589199066162, 0.0 
this is loss for epoch 632 : 6.670231819152832, 0.0 
this is loss for epoch 633 : 6.659879207611084, 0.0 
this is loss for epoch 634 : 6.649445533752441, 0.0 
this is loss for epoch 635 : 6.638950347900391, 0.0 
this is loss for epoch 636 : 6.628363132476807, 0.0 
this is loss for epoch 637 : 6.61775064468383

this is loss for epoch 774 : 5.427468776702881, 0.0 
this is loss for epoch 775 : 5.42022705078125, 0.0 
this is loss for epoch 776 : 5.413201808929443, 0.0 
this is loss for epoch 777 : 5.406217575073242, 0.0 
this is loss for epoch 778 : 5.399168491363525, 0.0 
this is loss for epoch 779 : 5.391892910003662, 0.0 
this is loss for epoch 780 : 5.384554862976074, 0.0 
this is loss for epoch 781 : 5.377361297607422, 0.0 
this is loss for epoch 782 : 5.37028694152832, 0.0 
this is loss for epoch 783 : 5.363219738006592, 0.0 
this is loss for epoch 784 : 5.356138229370117, 0.0 
this is loss for epoch 785 : 5.349147319793701, 0.0 
this is loss for epoch 786 : 5.342105388641357, 0.0 
this is loss for epoch 787 : 5.335171222686768, 0.0 
this is loss for epoch 788 : 5.328269004821777, 0.0 
this is loss for epoch 789 : 5.321323394775391, 0.0 
this is loss for epoch 790 : 5.314348220825195, 0.0 
this is loss for epoch 791 : 5.307322978973389, 0.0 
this is loss for epoch 792 : 5.300360202789307, 

this is loss for epoch 929 : 4.457954406738281, 0.0 
this is loss for epoch 930 : 4.452788829803467, 0.0 
this is loss for epoch 931 : 4.447473049163818, 0.0 
this is loss for epoch 932 : 4.442323207855225, 0.0 
this is loss for epoch 933 : 4.437268257141113, 0.0 
this is loss for epoch 934 : 4.432046413421631, 0.0 
this is loss for epoch 935 : 4.426812171936035, 0.0 
this is loss for epoch 936 : 4.421710968017578, 0.0 
this is loss for epoch 937 : 4.416634559631348, 0.0 
this is loss for epoch 938 : 4.411435604095459, 0.0 
this is loss for epoch 939 : 4.406282901763916, 0.0 
this is loss for epoch 940 : 4.401198863983154, 0.0 
this is loss for epoch 941 : 4.3960161209106445, 0.0 
this is loss for epoch 942 : 4.3908891677856445, 0.0 
this is loss for epoch 943 : 4.385791301727295, 0.0 
this is loss for epoch 944 : 4.380681037902832, 0.0 
this is loss for epoch 945 : 4.375591278076172, 0.0 
this is loss for epoch 946 : 4.370473384857178, 0.0 
this is loss for epoch 947 : 4.3653101921081

this is loss for epoch 1082 : 3.7414398193359375, 0.0 
this is loss for epoch 1083 : 3.7374393939971924, 0.0 
this is loss for epoch 1084 : 3.733509063720703, 0.0 
this is loss for epoch 1085 : 3.7295684814453125, 0.0 
this is loss for epoch 1086 : 3.725571870803833, 0.0 
this is loss for epoch 1087 : 3.721583127975464, 0.0 
this is loss for epoch 1088 : 3.7176291942596436, 0.0 
this is loss for epoch 1089 : 3.7137534618377686, 0.0 
this is loss for epoch 1090 : 3.7098395824432373, 0.0 
this is loss for epoch 1091 : 3.7059104442596436, 0.0 
this is loss for epoch 1092 : 3.701984167098999, 0.0 
this is loss for epoch 1093 : 3.6980628967285156, 0.0 
this is loss for epoch 1094 : 3.6941475868225098, 0.0 
this is loss for epoch 1095 : 3.6902663707733154, 0.0 
this is loss for epoch 1096 : 3.686354875564575, 0.0 
this is loss for epoch 1097 : 3.682454824447632, 0.0 
this is loss for epoch 1098 : 3.6785778999328613, 0.0 
this is loss for epoch 1099 : 3.6746551990509033, 0.0 
this is loss for

this is loss for epoch 1233 : 3.1690967082977295, 0.0 
this is loss for epoch 1234 : 3.166050672531128, 0.0 
this is loss for epoch 1235 : 3.1628315448760986, 0.0 
this is loss for epoch 1236 : 3.1597847938537598, 0.0 
this is loss for epoch 1237 : 3.156592607498169, 0.0 
this is loss for epoch 1238 : 3.153430700302124, 0.0 
this is loss for epoch 1239 : 3.1502556800842285, 0.0 
this is loss for epoch 1240 : 3.1471474170684814, 0.0 
this is loss for epoch 1241 : 3.1440494060516357, 0.0 
this is loss for epoch 1242 : 3.1409084796905518, 0.0 
this is loss for epoch 1243 : 3.1378064155578613, 0.0 
this is loss for epoch 1244 : 3.134631872177124, 0.0 
this is loss for epoch 1245 : 3.1315200328826904, 0.0 
this is loss for epoch 1246 : 3.128399133682251, 0.0 
this is loss for epoch 1247 : 3.1252965927124023, 0.0 
this is loss for epoch 1248 : 3.122185230255127, 0.0 
this is loss for epoch 1249 : 3.1189534664154053, 0.0 
this is loss for epoch 1250 : 3.1158487796783447, 0.0 
this is loss for

this is loss for epoch 1384 : 2.724700450897217, 0.0 
this is loss for epoch 1385 : 2.7220261096954346, 0.0 
this is loss for epoch 1386 : 2.7193286418914795, 0.0 
this is loss for epoch 1387 : 2.7167327404022217, 0.0 
this is loss for epoch 1388 : 2.714167356491089, 0.0 
this is loss for epoch 1389 : 2.7115542888641357, 0.0 
this is loss for epoch 1390 : 2.708930253982544, 0.0 
this is loss for epoch 1391 : 2.7062673568725586, 0.0 
this is loss for epoch 1392 : 2.7035720348358154, 0.0 
this is loss for epoch 1393 : 2.700904607772827, 0.0 
this is loss for epoch 1394 : 2.6982762813568115, 0.0 
this is loss for epoch 1395 : 2.6957428455352783, 0.0 
this is loss for epoch 1396 : 2.693089723587036, 0.0 
this is loss for epoch 1397 : 2.690248727798462, 0.0 
this is loss for epoch 1398 : 2.687678575515747, 0.0 
this is loss for epoch 1399 : 2.68510365486145, 0.0 
this is loss for epoch 1400 : 2.6825597286224365, 0.0 
this is loss for epoch 1401 : 2.6799733638763428, 0.0 
this is loss for ep

this is loss for epoch 1535 : 2.358656167984009, 0.0 
this is loss for epoch 1536 : 2.3564727306365967, 0.0 
this is loss for epoch 1537 : 2.3541760444641113, 0.0 
this is loss for epoch 1538 : 2.3520076274871826, 0.0 
this is loss for epoch 1539 : 2.3499302864074707, 0.0 
this is loss for epoch 1540 : 2.3476827144622803, 0.0 
this is loss for epoch 1541 : 2.3453526496887207, 0.0 
this is loss for epoch 1542 : 2.3430838584899902, 0.0 
this is loss for epoch 1543 : 2.340891122817993, 0.0 
this is loss for epoch 1544 : 2.3387017250061035, 0.0 
this is loss for epoch 1545 : 2.33657169342041, 0.0 
this is loss for epoch 1546 : 2.334439516067505, 0.0 
this is loss for epoch 1547 : 2.332272529602051, 0.0 
this is loss for epoch 1548 : 2.330048084259033, 0.0 
this is loss for epoch 1549 : 2.3278133869171143, 0.0 
this is loss for epoch 1550 : 2.3256824016571045, 0.0 
this is loss for epoch 1551 : 2.3235433101654053, 0.0 
this is loss for epoch 1552 : 2.321336507797241, 0.0 
this is loss for e

this is loss for epoch 1686 : 2.0260438919067383, 0.0 
this is loss for epoch 1687 : 2.024301290512085, 0.0 
this is loss for epoch 1688 : 2.0222222805023193, 0.0 
this is loss for epoch 1689 : 2.0201239585876465, 0.0 
this is loss for epoch 1690 : 2.0180938243865967, 0.0 
this is loss for epoch 1691 : 2.0162177085876465, 0.0 
this is loss for epoch 1692 : 2.01419734954834, 0.0 
this is loss for epoch 1693 : 2.0121638774871826, 0.0 
this is loss for epoch 1694 : 2.0102436542510986, 0.0 
this is loss for epoch 1695 : 2.008255958557129, 0.0 
this is loss for epoch 1696 : 2.006317615509033, 0.0 
this is loss for epoch 1697 : 2.004483461380005, 0.0 
this is loss for epoch 1698 : 2.002518892288208, 0.0 
this is loss for epoch 1699 : 2.0005710124969482, 0.0 
this is loss for epoch 1700 : 1.9986648559570312, 0.0 
this is loss for epoch 1701 : 1.996711015701294, 0.0 
this is loss for epoch 1702 : 1.9946975708007812, 0.0 
this is loss for epoch 1703 : 1.9926722049713135, 0.0 
this is loss for e

this is loss for epoch 1836 : 1.738885521888733, 0.0 
this is loss for epoch 1837 : 1.7371505498886108, 0.0 
this is loss for epoch 1838 : 1.7354522943496704, 0.0 
this is loss for epoch 1839 : 1.7336915731430054, 0.0 
this is loss for epoch 1840 : 1.732033371925354, 0.0 
this is loss for epoch 1841 : 1.7303203344345093, 0.0 
this is loss for epoch 1842 : 1.72854483127594, 0.0 
this is loss for epoch 1843 : 1.7268356084823608, 0.0 
this is loss for epoch 1844 : 1.7251437902450562, 0.0 
this is loss for epoch 1845 : 1.7235323190689087, 0.0 
this is loss for epoch 1846 : 1.7219946384429932, 0.0 
this is loss for epoch 1847 : 1.7204620838165283, 0.0 
this is loss for epoch 1848 : 1.718622088432312, 0.0 
this is loss for epoch 1849 : 1.716719627380371, 0.0 
this is loss for epoch 1850 : 1.7150840759277344, 0.0 
this is loss for epoch 1851 : 1.713476538658142, 0.0 
this is loss for epoch 1852 : 1.7117022275924683, 0.0 
this is loss for epoch 1853 : 1.7100738286972046, 0.0 
this is loss for 

this is loss for epoch 1986 : 1.502102017402649, 0.0 
this is loss for epoch 1987 : 1.5004535913467407, 0.0 
this is loss for epoch 1988 : 1.4988082647323608, 0.0 
this is loss for epoch 1989 : 1.4976338148117065, 0.0 
this is loss for epoch 1990 : 1.495895266532898, 0.0 
this is loss for epoch 1991 : 1.4943565130233765, 0.0 
this is loss for epoch 1992 : 1.4930089712142944, 0.0 
this is loss for epoch 1993 : 1.4913080930709839, 0.0 
this is loss for epoch 1994 : 1.4898674488067627, 0.0 
this is loss for epoch 1995 : 1.4885168075561523, 0.0 
this is loss for epoch 1996 : 1.4870268106460571, 0.0 
this is loss for epoch 1997 : 1.4856001138687134, 0.0 
this is loss for epoch 1998 : 1.4843136072158813, 0.0 
this is loss for epoch 1999 : 1.4828184843063354, 0.0 
this is loss for epoch 2000 : 1.481317400932312, 0.0 
this is loss for epoch 2001 : 1.479990005493164, 0.0 
this is loss for epoch 2002 : 1.4786494970321655, 0.0 
this is loss for epoch 2003 : 1.477147102355957, 0.0 
this is loss fo

this is loss for epoch 2136 : 1.304447054862976, 0.0 
this is loss for epoch 2137 : 1.3033303022384644, 0.0 
this is loss for epoch 2138 : 1.3019884824752808, 0.0 
this is loss for epoch 2139 : 1.3007867336273193, 0.0 
this is loss for epoch 2140 : 1.299495816230774, 0.0 
this is loss for epoch 2141 : 1.2984025478363037, 0.0 
this is loss for epoch 2142 : 1.297262191772461, 0.0 
this is loss for epoch 2143 : 1.296061396598816, 0.0 
this is loss for epoch 2144 : 1.2948824167251587, 0.0 
this is loss for epoch 2145 : 1.2937371730804443, 0.0 
this is loss for epoch 2146 : 1.2926186323165894, 0.0 
this is loss for epoch 2147 : 1.2914609909057617, 0.0 
this is loss for epoch 2148 : 1.2903187274932861, 0.0 
this is loss for epoch 2149 : 1.2891424894332886, 0.0 
this is loss for epoch 2150 : 1.2879836559295654, 0.0 
this is loss for epoch 2151 : 1.2867851257324219, 0.0 
this is loss for epoch 2152 : 1.285643458366394, 0.0 
this is loss for epoch 2153 : 1.2845276594161987, 0.0 
this is loss fo

this is loss for epoch 2286 : 1.1512993574142456, 0.0 
this is loss for epoch 2287 : 1.1503385305404663, 0.0 
this is loss for epoch 2288 : 1.1492884159088135, 0.0 
this is loss for epoch 2289 : 1.148269772529602, 0.0 
this is loss for epoch 2290 : 1.1472855806350708, 0.0 
this is loss for epoch 2291 : 1.1463056802749634, 0.0 
this is loss for epoch 2292 : 1.1453431844711304, 0.0 
this is loss for epoch 2293 : 1.1443853378295898, 0.0 
this is loss for epoch 2294 : 1.1433948278427124, 0.0 
this is loss for epoch 2295 : 1.1423726081848145, 0.0 
this is loss for epoch 2296 : 1.1414047479629517, 0.0 
this is loss for epoch 2297 : 1.140468716621399, 0.0 
this is loss for epoch 2298 : 1.1395378112792969, 0.0 
this is loss for epoch 2299 : 1.1385711431503296, 0.0 
this is loss for epoch 2300 : 1.1376053094863892, 0.0 
this is loss for epoch 2301 : 1.1367405652999878, 0.0 
this is loss for epoch 2302 : 1.136290192604065, 0.0 
this is loss for epoch 2303 : 1.1352474689483643, 0.0 
this is loss 

this is loss for epoch 2436 : 1.0225812196731567, 0.0 
this is loss for epoch 2437 : 1.0217944383621216, 0.0 
this is loss for epoch 2438 : 1.0209563970565796, 0.0 
this is loss for epoch 2439 : 1.0201362371444702, 0.0 
this is loss for epoch 2440 : 1.0193893909454346, 0.0 
this is loss for epoch 2441 : 1.0186246633529663, 0.0 
this is loss for epoch 2442 : 1.0178663730621338, 0.0 
this is loss for epoch 2443 : 1.0171699523925781, 0.0 
this is loss for epoch 2444 : 1.0164519548416138, 0.0 
this is loss for epoch 2445 : 1.0157297849655151, 0.0 
this is loss for epoch 2446 : 1.0149741172790527, 0.0 
this is loss for epoch 2447 : 1.014245867729187, 0.0 
this is loss for epoch 2448 : 1.0135115385055542, 0.0 
this is loss for epoch 2449 : 1.012757420539856, 0.0 
this is loss for epoch 2450 : 1.011962652206421, 0.0 
this is loss for epoch 2451 : 1.0111687183380127, 0.0 
this is loss for epoch 2452 : 1.0104124546051025, 0.0 
this is loss for epoch 2453 : 1.0096920728683472, 0.0 
this is loss 

this is loss for epoch 2586 : 0.9066033363342285, 0.0 
this is loss for epoch 2587 : 0.9060032963752747, 0.0 
this is loss for epoch 2588 : 0.9053751826286316, 0.0 
this is loss for epoch 2589 : 0.9047166705131531, 0.0 
this is loss for epoch 2590 : 0.9040835499763489, 0.0 
this is loss for epoch 2591 : 0.9034824371337891, 0.0 
this is loss for epoch 2592 : 0.9028858542442322, 0.0 
this is loss for epoch 2593 : 0.9022696614265442, 0.0 
this is loss for epoch 2594 : 0.9016386866569519, 0.0 
this is loss for epoch 2595 : 0.9010011553764343, 0.0 
this is loss for epoch 2596 : 0.900390625, 0.0 
this is loss for epoch 2597 : 0.8998357057571411, 0.0 
this is loss for epoch 2598 : 0.8993242383003235, 0.0 
this is loss for epoch 2599 : 0.8987835049629211, 0.0 
this is loss for epoch 2600 : 0.8982211947441101, 0.0 
this is loss for epoch 2601 : 0.8976812362670898, 0.0 
this is loss for epoch 2602 : 0.8972072601318359, 0.0 
this is loss for epoch 2603 : 0.8967165350914001, 0.0 
this is loss for 

this is loss for epoch 2736 : 0.8151394724845886, 0.0 
this is loss for epoch 2737 : 0.8145089149475098, 0.0 
this is loss for epoch 2738 : 0.8138778805732727, 0.0 
this is loss for epoch 2739 : 0.8132902383804321, 0.0 
this is loss for epoch 2740 : 0.812714695930481, 0.0 
this is loss for epoch 2741 : 0.8121399879455566, 0.0 
this is loss for epoch 2742 : 0.8116335868835449, 0.0 
this is loss for epoch 2743 : 0.8110403418540955, 0.0 
this is loss for epoch 2744 : 0.8105164766311646, 0.0 
this is loss for epoch 2745 : 0.8099417090415955, 0.0 
this is loss for epoch 2746 : 0.8093628883361816, 0.0 
this is loss for epoch 2747 : 0.8088691830635071, 0.0 
this is loss for epoch 2748 : 0.8083650469779968, 0.0 
this is loss for epoch 2749 : 0.8078140616416931, 0.0 
this is loss for epoch 2750 : 0.8072969317436218, 0.0 
this is loss for epoch 2751 : 0.8067770004272461, 0.0 
this is loss for epoch 2752 : 0.8062516450881958, 0.0 
this is loss for epoch 2753 : 0.8057605624198914, 0.0 
this is los

this is loss for epoch 2886 : 0.7365545630455017, 0.0 
this is loss for epoch 2887 : 0.7360832095146179, 0.0 
this is loss for epoch 2888 : 0.7355929613113403, 0.0 
this is loss for epoch 2889 : 0.735093891620636, 0.0 
this is loss for epoch 2890 : 0.7346128821372986, 0.0 
this is loss for epoch 2891 : 0.7341344952583313, 0.0 
this is loss for epoch 2892 : 0.7336652278900146, 0.0 
this is loss for epoch 2893 : 0.733225405216217, 0.0 
this is loss for epoch 2894 : 0.7327920794487, 0.0 
this is loss for epoch 2895 : 0.732369065284729, 0.0 
this is loss for epoch 2896 : 0.7319750189781189, 0.0 
this is loss for epoch 2897 : 0.7316129803657532, 0.0 
this is loss for epoch 2898 : 0.7312765717506409, 0.0 
this is loss for epoch 2899 : 0.7309654355049133, 0.0 
this is loss for epoch 2900 : 0.7307077646255493, 0.0 
this is loss for epoch 2901 : 0.7305139303207397, 0.0 
this is loss for epoch 2902 : 0.7303723692893982, 0.0 
this is loss for epoch 2903 : 0.7302622199058533, 0.0 
this is loss for

this is loss for epoch 3036 : 0.6730782985687256, 0.0 
this is loss for epoch 3037 : 0.6719844937324524, 0.0 
this is loss for epoch 3038 : 0.6704999208450317, 0.0 
this is loss for epoch 3039 : 0.6689414381980896, 0.0 
this is loss for epoch 3040 : 0.6675497889518738, 0.0 
this is loss for epoch 3041 : 0.6664909720420837, 0.0 
this is loss for epoch 3042 : 0.66584712266922, 0.0 
this is loss for epoch 3043 : 0.6655095219612122, 0.0 
this is loss for epoch 3044 : 0.665270209312439, 0.0 
this is loss for epoch 3045 : 0.6649156808853149, 0.0 
this is loss for epoch 3046 : 0.6643595695495605, 0.0 
this is loss for epoch 3047 : 0.6635967493057251, 0.0 
this is loss for epoch 3048 : 0.6627640724182129, 0.0 
this is loss for epoch 3049 : 0.6619881391525269, 0.0 
this is loss for epoch 3050 : 0.6613566279411316, 0.0 
this is loss for epoch 3051 : 0.660817563533783, 0.0 
this is loss for epoch 3052 : 0.6603574156761169, 0.0 
this is loss for epoch 3053 : 0.6599699258804321, 0.0 
this is loss f

this is loss for epoch 3186 : 0.6070224642753601, 0.0 
this is loss for epoch 3187 : 0.6066058874130249, 0.0 
this is loss for epoch 3188 : 0.6062588095664978, 0.0 
this is loss for epoch 3189 : 0.6059606075286865, 0.0 
this is loss for epoch 3190 : 0.605659544467926, 0.0 
this is loss for epoch 3191 : 0.6053730845451355, 0.0 
this is loss for epoch 3192 : 0.6050798296928406, 0.0 
this is loss for epoch 3193 : 0.6048305034637451, 0.0 
this is loss for epoch 3194 : 0.6046218276023865, 0.0 
this is loss for epoch 3195 : 0.6044657826423645, 0.0 
this is loss for epoch 3196 : 0.6043070554733276, 0.0 
this is loss for epoch 3197 : 0.6042062640190125, 0.0 
this is loss for epoch 3198 : 0.6041799783706665, 0.0 
this is loss for epoch 3199 : 0.6042375564575195, 0.0 
this is loss for epoch 3200 : 0.6043127775192261, 0.0 
this is loss for epoch 3201 : 0.6044062376022339, 0.0 
this is loss for epoch 3202 : 0.6043865084648132, 0.0 
this is loss for epoch 3203 : 0.6042593717575073, 0.0 
this is los

this is loss for epoch 3336 : 0.5496124625205994, 0.0 
this is loss for epoch 3337 : 0.549113392829895, 0.0 
this is loss for epoch 3338 : 0.5486355423927307, 0.0 
this is loss for epoch 3339 : 0.5482062697410583, 0.0 
this is loss for epoch 3340 : 0.5478553175926208, 0.0 
this is loss for epoch 3341 : 0.5475233197212219, 0.0 
this is loss for epoch 3342 : 0.5471871495246887, 0.0 
this is loss for epoch 3343 : 0.546790599822998, 0.0 
this is loss for epoch 3344 : 0.546302318572998, 0.0 
this is loss for epoch 3345 : 0.5459414720535278, 0.0 
this is loss for epoch 3346 : 0.5455610752105713, 0.0 
this is loss for epoch 3347 : 0.5451468229293823, 0.0 
this is loss for epoch 3348 : 0.5448228716850281, 0.0 
this is loss for epoch 3349 : 0.5444735884666443, 0.0 
this is loss for epoch 3350 : 0.5441825985908508, 0.0 
this is loss for epoch 3351 : 0.5438217520713806, 0.0 
this is loss for epoch 3352 : 0.5434583425521851, 0.0 
this is loss for epoch 3353 : 0.5430768132209778, 0.0 
this is loss 

this is loss for epoch 3486 : 0.5007396340370178, 0.0 
this is loss for epoch 3487 : 0.5001250505447388, 0.0 
this is loss for epoch 3488 : 0.4996884763240814, 0.0 
this is loss for epoch 3489 : 0.4993092119693756, 0.0 
this is loss for epoch 3490 : 0.49887529015541077, 0.0 
this is loss for epoch 3491 : 0.49847450852394104, 0.0 
this is loss for epoch 3492 : 0.4982523024082184, 0.0 
this is loss for epoch 3493 : 0.498080313205719, 0.0 
this is loss for epoch 3494 : 0.4976452887058258, 0.0 
this is loss for epoch 3495 : 0.4971117079257965, 0.0 
this is loss for epoch 3496 : 0.49652695655822754, 0.0 
this is loss for epoch 3497 : 0.4957299828529358, 0.0 
this is loss for epoch 3498 : 0.4947231709957123, 0.0 
this is loss for epoch 3499 : 0.49392762780189514, 0.0 
this is loss for epoch 3500 : 0.49396106600761414, 0.0 
this is loss for epoch 3501 : 0.49359560012817383, 0.0 
this is loss for epoch 3502 : 0.49328213930130005, 0.0 
this is loss for epoch 3503 : 0.4927336871623993, 0.0 
this

this is loss for epoch 3635 : 0.45275112986564636, 0.0 
this is loss for epoch 3636 : 0.45268067717552185, 0.0 
this is loss for epoch 3637 : 0.452738493680954, 0.0 
this is loss for epoch 3638 : 0.45286574959754944, 0.0 
this is loss for epoch 3639 : 0.45296475291252136, 0.0 
this is loss for epoch 3640 : 0.4530223309993744, 0.0 
this is loss for epoch 3641 : 0.4531216621398926, 0.0 
this is loss for epoch 3642 : 0.4533993899822235, 0.0 
this is loss for epoch 3643 : 0.45380422472953796, 0.0 
this is loss for epoch 3644 : 0.4541918933391571, 0.0 
this is loss for epoch 3645 : 0.45437678694725037, 0.0 
this is loss for epoch 3646 : 0.4543161392211914, 0.0 
this is loss for epoch 3647 : 0.45406460762023926, 0.0 
this is loss for epoch 3648 : 0.4535748064517975, 0.0 
this is loss for epoch 3649 : 0.45272284746170044, 0.0 
this is loss for epoch 3650 : 0.45148858428001404, 0.0 
this is loss for epoch 3651 : 0.45010486245155334, 0.0 
this is loss for epoch 3652 : 0.4488878846168518, 0.0 
t

this is loss for epoch 3783 : 0.41501593589782715, 0.0 
this is loss for epoch 3784 : 0.4141010344028473, 0.0 
this is loss for epoch 3785 : 0.4136790931224823, 0.0 
this is loss for epoch 3786 : 0.41354313492774963, 0.0 
this is loss for epoch 3787 : 0.413358598947525, 0.0 
this is loss for epoch 3788 : 0.41296082735061646, 0.0 
this is loss for epoch 3789 : 0.4122171998023987, 0.0 
this is loss for epoch 3790 : 0.4112437665462494, 0.0 
this is loss for epoch 3791 : 0.4102631211280823, 0.0 
this is loss for epoch 3792 : 0.4095345437526703, 0.0 
this is loss for epoch 3793 : 0.4090414047241211, 0.0 
this is loss for epoch 3794 : 0.4086821675300598, 0.0 
this is loss for epoch 3795 : 0.40844616293907166, 0.0 
this is loss for epoch 3796 : 0.4082920253276825, 0.0 
this is loss for epoch 3797 : 0.4080500602722168, 0.0 
this is loss for epoch 3798 : 0.4076438844203949, 0.0 
this is loss for epoch 3799 : 0.40717679262161255, 0.0 
this is loss for epoch 3800 : 0.4067046344280243, 0.0 
this i

this is loss for epoch 3931 : 0.3730491101741791, 0.0 
this is loss for epoch 3932 : 0.37298503518104553, 0.0 
this is loss for epoch 3933 : 0.3729395866394043, 0.0 
this is loss for epoch 3934 : 0.372936874628067, 0.0 
this is loss for epoch 3935 : 0.3729976713657379, 0.0 
this is loss for epoch 3936 : 0.37308716773986816, 0.0 
this is loss for epoch 3937 : 0.3731819689273834, 0.0 
this is loss for epoch 3938 : 0.3732563555240631, 0.0 
this is loss for epoch 3939 : 0.3732522428035736, 0.0 
this is loss for epoch 3940 : 0.3731224536895752, 0.0 
this is loss for epoch 3941 : 0.3728022277355194, 0.0 
this is loss for epoch 3942 : 0.372224360704422, 0.0 
this is loss for epoch 3943 : 0.3715362250804901, 0.0 
this is loss for epoch 3944 : 0.3707650899887085, 0.0 
this is loss for epoch 3945 : 0.36991703510284424, 0.0 
this is loss for epoch 3946 : 0.3690924346446991, 0.0 
this is loss for epoch 3947 : 0.3684239089488983, 0.0 
this is loss for epoch 3948 : 0.36787980794906616, 0.0 
this is 

this is loss for epoch 4080 : 0.3382294476032257, 0.0 
this is loss for epoch 4081 : 0.3380682170391083, 0.0 
this is loss for epoch 4082 : 0.33790862560272217, 0.0 
this is loss for epoch 4083 : 0.33769887685775757, 0.0 
this is loss for epoch 4084 : 0.33740121126174927, 0.0 
this is loss for epoch 4085 : 0.3370712399482727, 0.0 
this is loss for epoch 4086 : 0.33677220344543457, 0.0 
this is loss for epoch 4087 : 0.33651217818260193, 0.0 
this is loss for epoch 4088 : 0.3362867534160614, 0.0 
this is loss for epoch 4089 : 0.3361043334007263, 0.0 
this is loss for epoch 4090 : 0.3359481692314148, 0.0 
this is loss for epoch 4091 : 0.3357888162136078, 0.0 
this is loss for epoch 4092 : 0.33560553193092346, 0.0 
this is loss for epoch 4093 : 0.3353915810585022, 0.0 
this is loss for epoch 4094 : 0.33521029353141785, 0.0 
this is loss for epoch 4095 : 0.3349769711494446, 0.0 
this is loss for epoch 4096 : 0.33473294973373413, 0.0 
this is loss for epoch 4097 : 0.3344564139842987, 0.0 
th

this is loss for epoch 4228 : 0.3071788549423218, 0.0 
this is loss for epoch 4229 : 0.3068884313106537, 0.0 
this is loss for epoch 4230 : 0.3067445158958435, 0.0 
this is loss for epoch 4231 : 0.3066309988498688, 0.0 
this is loss for epoch 4232 : 0.3065321743488312, 0.0 
this is loss for epoch 4233 : 0.30638453364372253, 0.0 
this is loss for epoch 4234 : 0.3062170147895813, 0.0 
this is loss for epoch 4235 : 0.3060021996498108, 0.0 
this is loss for epoch 4236 : 0.3057599365711212, 0.0 
this is loss for epoch 4237 : 0.30547115206718445, 0.0 
this is loss for epoch 4238 : 0.30516502261161804, 0.0 
this is loss for epoch 4239 : 0.30485233664512634, 0.0 
this is loss for epoch 4240 : 0.30463218688964844, 0.0 
this is loss for epoch 4241 : 0.3044233024120331, 0.0 
this is loss for epoch 4242 : 0.3044055998325348, 0.0 
this is loss for epoch 4243 : 0.30443254113197327, 0.0 
this is loss for epoch 4244 : 0.30451279878616333, 0.0 
this is loss for epoch 4245 : 0.30471205711364746, 0.0 
th

this is loss for epoch 4376 : 0.28123709559440613, 0.0 
this is loss for epoch 4377 : 0.2811345160007477, 0.0 
this is loss for epoch 4378 : 0.28103137016296387, 0.0 
this is loss for epoch 4379 : 0.2809831202030182, 0.0 
this is loss for epoch 4380 : 0.2809954881668091, 0.0 
this is loss for epoch 4381 : 0.2810244858264923, 0.0 
this is loss for epoch 4382 : 0.28111669421195984, 0.0 
this is loss for epoch 4383 : 0.2813509702682495, 0.0 
this is loss for epoch 4384 : 0.2817417085170746, 0.0 
this is loss for epoch 4385 : 0.28227391839027405, 0.0 
this is loss for epoch 4386 : 0.2829296886920929, 0.0 
this is loss for epoch 4387 : 0.28369012475013733, 0.0 
this is loss for epoch 4388 : 0.2845281660556793, 0.0 
this is loss for epoch 4389 : 0.28536006808280945, 0.0 
this is loss for epoch 4390 : 0.2860739827156067, 0.0 
this is loss for epoch 4391 : 0.286448210477829, 0.0 
this is loss for epoch 4392 : 0.2863399386405945, 0.0 
this is loss for epoch 4393 : 0.28558772802352905, 0.0 
this

this is loss for epoch 4525 : 0.2582669258117676, 0.0 
this is loss for epoch 4526 : 0.2581345736980438, 0.0 
this is loss for epoch 4527 : 0.25799885392189026, 0.0 
this is loss for epoch 4528 : 0.2578631043434143, 0.0 
this is loss for epoch 4529 : 0.25773245096206665, 0.0 
this is loss for epoch 4530 : 0.25761398673057556, 0.0 
this is loss for epoch 4531 : 0.25750666856765747, 0.0 
this is loss for epoch 4532 : 0.25740599632263184, 0.0 
this is loss for epoch 4533 : 0.2573211193084717, 0.0 
this is loss for epoch 4534 : 0.25726163387298584, 0.0 
this is loss for epoch 4535 : 0.2571963667869568, 0.0 
this is loss for epoch 4536 : 0.25711703300476074, 0.0 
this is loss for epoch 4537 : 0.2570224106311798, 0.0 
this is loss for epoch 4538 : 0.2569564878940582, 0.0 
this is loss for epoch 4539 : 0.2569265067577362, 0.0 
this is loss for epoch 4540 : 0.2569120526313782, 0.0 
this is loss for epoch 4541 : 0.256902813911438, 0.0 
this is loss for epoch 4542 : 0.25696641206741333, 0.0 
thi

this is loss for epoch 4673 : 0.24013614654541016, 0.0 
this is loss for epoch 4674 : 0.239223912358284, 0.0 
this is loss for epoch 4675 : 0.23964880406856537, 0.0 
this is loss for epoch 4676 : 0.2406909018754959, 0.0 
this is loss for epoch 4677 : 0.24175655841827393, 0.0 
this is loss for epoch 4678 : 0.24236014485359192, 0.0 
this is loss for epoch 4679 : 0.2420673370361328, 0.0 
this is loss for epoch 4680 : 0.2409912645816803, 0.0 
this is loss for epoch 4681 : 0.23964665830135345, 0.0 
this is loss for epoch 4682 : 0.23852983117103577, 0.0 
this is loss for epoch 4683 : 0.23780818283557892, 0.0 
this is loss for epoch 4684 : 0.2375786155462265, 0.0 
this is loss for epoch 4685 : 0.23777882754802704, 0.0 
this is loss for epoch 4686 : 0.23811717331409454, 0.0 
this is loss for epoch 4687 : 0.2382904291152954, 0.0 
this is loss for epoch 4688 : 0.2381795197725296, 0.0 
this is loss for epoch 4689 : 0.23785507678985596, 0.0 
this is loss for epoch 4690 : 0.23737946152687073, 0.0 


this is loss for epoch 4821 : 0.22105562686920166, 0.0 
this is loss for epoch 4822 : 0.22099943459033966, 0.0 
this is loss for epoch 4823 : 0.22093163430690765, 0.0 
this is loss for epoch 4824 : 0.2208646535873413, 0.0 
this is loss for epoch 4825 : 0.22082820534706116, 0.0 
this is loss for epoch 4826 : 0.2208152562379837, 0.0 
this is loss for epoch 4827 : 0.22081677615642548, 0.0 
this is loss for epoch 4828 : 0.22085368633270264, 0.0 
this is loss for epoch 4829 : 0.2209378331899643, 0.0 
this is loss for epoch 4830 : 0.2210317701101303, 0.0 
this is loss for epoch 4831 : 0.22113148868083954, 0.0 
this is loss for epoch 4832 : 0.22126613557338715, 0.0 
this is loss for epoch 4833 : 0.22148175537586212, 0.0 
this is loss for epoch 4834 : 0.22176285088062286, 0.0 
this is loss for epoch 4835 : 0.22214365005493164, 0.0 
this is loss for epoch 4836 : 0.22252492606639862, 0.0 
this is loss for epoch 4837 : 0.2229338139295578, 0.0 
this is loss for epoch 4838 : 0.2232195883989334, 0.0

this is loss for epoch 4969 : 0.20619097352027893, 0.0 
this is loss for epoch 4970 : 0.20602701604366302, 0.0 
this is loss for epoch 4971 : 0.20616531372070312, 0.0 
this is loss for epoch 4972 : 0.20624351501464844, 0.0 
this is loss for epoch 4973 : 0.20619861781597137, 0.0 
this is loss for epoch 4974 : 0.20640364289283752, 0.0 
this is loss for epoch 4975 : 0.20667408406734467, 0.0 
this is loss for epoch 4976 : 0.2068798542022705, 0.0 
this is loss for epoch 4977 : 0.20729923248291016, 0.0 
this is loss for epoch 4978 : 0.20777960121631622, 0.0 
this is loss for epoch 4979 : 0.20823080837726593, 0.0 
this is loss for epoch 4980 : 0.20884791016578674, 0.0 
this is loss for epoch 4981 : 0.20941030979156494, 0.0 
this is loss for epoch 4982 : 0.20991602540016174, 0.0 
this is loss for epoch 4983 : 0.21034982800483704, 0.0 
this is loss for epoch 4984 : 0.21052025258541107, 0.0 
this is loss for epoch 4985 : 0.21039700508117676, 0.0 
this is loss for epoch 4986 : 0.2099517434835434,

this is loss for epoch 5117 : 0.19196335971355438, 0.0 
this is loss for epoch 5118 : 0.19299927353858948, 0.0 
this is loss for epoch 5119 : 0.19446128606796265, 0.0 
this is loss for epoch 5120 : 0.19505201280117035, 0.0 
this is loss for epoch 5121 : 0.19430720806121826, 0.0 
this is loss for epoch 5122 : 0.19274453818798065, 0.0 
this is loss for epoch 5123 : 0.19130365550518036, 0.0 
this is loss for epoch 5124 : 0.19061143696308136, 0.0 
this is loss for epoch 5125 : 0.190757617354393, 0.0 
this is loss for epoch 5126 : 0.19137214124202728, 0.0 
this is loss for epoch 5127 : 0.19190572202205658, 0.0 
this is loss for epoch 5128 : 0.19193752110004425, 0.0 
this is loss for epoch 5129 : 0.19136443734169006, 0.0 
this is loss for epoch 5130 : 0.19035834074020386, 0.0 
this is loss for epoch 5131 : 0.18966440856456757, 0.0 
this is loss for epoch 5132 : 0.18965663015842438, 0.0 
this is loss for epoch 5133 : 0.18985429406166077, 0.0 
this is loss for epoch 5134 : 0.18990498781204224,

this is loss for epoch 5265 : 0.1775362491607666, 0.0 
this is loss for epoch 5266 : 0.17724238336086273, 0.0 
this is loss for epoch 5267 : 0.17704857885837555, 0.0 
this is loss for epoch 5268 : 0.17695216834545135, 0.0 
this is loss for epoch 5269 : 0.17693106830120087, 0.0 
this is loss for epoch 5270 : 0.17694735527038574, 0.0 
this is loss for epoch 5271 : 0.17696422338485718, 0.0 
this is loss for epoch 5272 : 0.1769586056470871, 0.0 
this is loss for epoch 5273 : 0.1769271194934845, 0.0 
this is loss for epoch 5274 : 0.17687687277793884, 0.0 
this is loss for epoch 5275 : 0.17681443691253662, 0.0 
this is loss for epoch 5276 : 0.17673994600772858, 0.0 
this is loss for epoch 5277 : 0.1766493022441864, 0.0 
this is loss for epoch 5278 : 0.17654214799404144, 0.0 
this is loss for epoch 5279 : 0.17642365396022797, 0.0 
this is loss for epoch 5280 : 0.1763039082288742, 0.0 
this is loss for epoch 5281 : 0.17619089782238007, 0.0 
this is loss for epoch 5282 : 0.1760888546705246, 0.0

this is loss for epoch 5413 : 0.179814875125885, 0.0 
this is loss for epoch 5414 : 0.1746210902929306, 0.0 
this is loss for epoch 5415 : 0.1689480096101761, 0.0 
this is loss for epoch 5416 : 0.16649405658245087, 0.0 
this is loss for epoch 5417 : 0.16778935492038727, 0.0 
this is loss for epoch 5418 : 0.1705242395401001, 0.0 
this is loss for epoch 5419 : 0.1719031184911728, 0.0 
this is loss for epoch 5420 : 0.17070718109607697, 0.0 
this is loss for epoch 5421 : 0.16806276142597198, 0.0 
this is loss for epoch 5422 : 0.16599102318286896, 0.0 
this is loss for epoch 5423 : 0.16568289697170258, 0.0 
this is loss for epoch 5424 : 0.166696235537529, 0.0 
this is loss for epoch 5425 : 0.16764666140079498, 0.0 
this is loss for epoch 5426 : 0.1675303876399994, 0.0 
this is loss for epoch 5427 : 0.16642698645591736, 0.0 
this is loss for epoch 5428 : 0.16518737375736237, 0.0 
this is loss for epoch 5429 : 0.16458198428153992, 0.0 
this is loss for epoch 5430 : 0.16471199691295624, 0.0 
t

this is loss for epoch 5561 : 0.15936200320720673, 0.0 
this is loss for epoch 5562 : 0.15642622113227844, 0.0 
this is loss for epoch 5563 : 0.15505102276802063, 0.0 
this is loss for epoch 5564 : 0.15517637133598328, 0.0 
this is loss for epoch 5565 : 0.15629112720489502, 0.0 
this is loss for epoch 5566 : 0.1576739400625229, 0.0 
this is loss for epoch 5567 : 0.158675417304039, 0.0 
this is loss for epoch 5568 : 0.15884384512901306, 0.0 
this is loss for epoch 5569 : 0.15816648304462433, 0.0 
this is loss for epoch 5570 : 0.1569076031446457, 0.0 
this is loss for epoch 5571 : 0.15554200112819672, 0.0 
this is loss for epoch 5572 : 0.15445521473884583, 0.0 
this is loss for epoch 5573 : 0.15385417640209198, 0.0 
this is loss for epoch 5574 : 0.15373791754245758, 0.0 
this is loss for epoch 5575 : 0.15396051108837128, 0.0 
this is loss for epoch 5576 : 0.15431438386440277, 0.0 
this is loss for epoch 5577 : 0.15459579229354858, 0.0 
this is loss for epoch 5578 : 0.15467020869255066, 0

KeyboardInterrupt: 

In [None]:
if __name__ == '__main__':
    seed = 2000
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    points_info,normals_info,face_info = read_obj('frame_000001.obj')
    old_ones = np.ones((points_info.shape[0],1))

    points_info = np.concatenate((points_info,old_ones),axis=1)

    old_info = points_info
    
    all_info_list = []
    for i in range(80,90):
        path = "frame_0000{}.obj".format(i+1)

        tmp_target = read_obj(path)[0]
        target_ones = np.ones((tmp_target.shape[0],1))
        tmp_target_new = np.concatenate((tmp_target,target_ones),axis=1)
        all_info_list.append(tmp_target_new)

    all_info = np.stack(all_info_list)
    old_info = points_info
    target_info = all_info
    # W,R= train(old_info=points_info,target_info=all_info)
#     bp()
    N = points_info.shape[0]
    F = target_info.shape[0]
    B = 40  ### 这个肯定可以通过读pkl得到

    center, indexes = kmeans_plusplus(old_info[:, :-1], n_clusters=B, random_state=0)
    old_mesh = torch.tensor(old_info).float().cuda().detach()
    # bp()
    # ones = torch.ones(old_mesh.shape[0], 1).cuda()
    # x = torch.cat([old_mesh, ones], dim=1)

    center_t = torch.tensor(center).repeat(F, 1, 1).cuda()
    target_mesh = torch.tensor(target_info, requires_grad=False).float().cuda().detach()

    
    
    p1 = torch.randn((F, B, 7), requires_grad=True, device="cuda")

#     W1 = FF.softmax(W * 10).clone().detach()
#     W1 = (W1 / (W1.sum(1, keepdim=True).detach()))

    optimizer = optim.Adam([p1],lr=1e-2)
    chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
#     W = torch.tensor(W)
    
#     W1 = np.load('weight.npy')
    W1 = torch.tensor(W1).cuda()
    la_loss = LaplacianLoss(torch.tensor(points_info),torch.tensor(face_info-1),average=True)
    
    for i in range(10000):
        if i == 400:
            optimizer = optim.Adam([p1],lr=1e-3)
        optimizer.zero_grad()
        loss = 0
        f_idxs = random.sample(range(F), F)
        
        for f in f_idxs:
            T = torch.randn(B, 4, 4)

            for b in range(B):
                # R = SO3.InitFromVec(p1[f, b])
                # bp()
                # bp()
                # with torch.no_grad():
                # temp = torch.cat([p1[f, b], dist[f][:-1]])
                R = SE3.InitFromVec(p1[f, b])
                # bp()
                T[b] = R.matrix()
            T = T.cuda()     
        
            x = old_mesh
        
        
            bx = (T@x.T).permute(2, 0, 1)
            wbx = W1.unsqueeze(2) * bx
            # bp()
            wbx =  wbx.permute((1,0,2))40
            # print("Wbx", wbx.shape)
            # wbx = wbx.sum(0, keepdim=True)[:, :, :-1]
            wbx = wbx.sum(0, keepdim=True)
            # loss_init = chamLoss(target_mesh[f][None],wbx)
            # loss += loss_init[0].sum()
#             bp()
            loss += torch.sum((target_mesh[f]-wbx)**2) 
#             print(loss)
        # bp()
        loss.backward()
        # print("this is loss for epoch {} : {} ".format(i,loss.detach().cpu().numpy()))
        print("this is loss for epoch {} : {} ".format(i,loss / F))
        # print("p1: ", p1)
        optimizer.step()


    new_R = SE3.InitFromVec(p1)
    R = new_R.matrix()

    # W1 = (W - W.min()) / (W.max() - W.min())
    # W1 = (W1 / (W1.sum(1, keepdim=True).detach()))
    # bp()
    # return R.detach(), W1.detach()

In [30]:
la_loss = LaplacianLoss(points_info,points_face-1,average=True)

In [31]:
la_loss(torch.tensor(wbx_result[-1][:, :-1]))

tensor(0.0052)

In [28]:
ccc = torch.tensor([[1,2,3],[4,5,6]])
ccc.size(0)

2

In [29]:
points_info.shape

(3199, 4)

In [38]:
points_info,points_normal,points_face = read_obj('frame_000001.obj')
# all_info_list = []

# for i in range(1, 9):
#     path = "frame_00000{}.obj".format(i+1)
#     all_info_list.append(read_obj(path))

In [42]:
points_info.shape

(711, 3)

In [222]:
points_face.shape

(8726,)

In [6]:
R2 = R[:, :, :-1,:-1]
R2.shape

(8, 45, 3, 3)

In [6]:
import torch
W = torch.tensor(W).cuda()

  


In [93]:
old_mesh = torch.tensor(points_info).float().cuda()
# W1 = (W - W.min()) / (W.max() - W.min())
W1 = FF.softmax(W)
W1 = (W1 / (W1.sum(1, keepdim=True).detach()))
W1.shape

  This is separate from the ipykernel package so we can avoid doing imports until


torch.Size([8315, 45])

In [17]:
R = torch.tensor(R).cuda()

In [11]:
(W1.unsqueeze(2) * (R[0]@old_mesh.T).permute(2, 0, 1)).permute(1, 0, 2).sum(0).shape

TypeError: unsupported operand type(s) for @: 'numpy.ndarray' and 'Tensor'

In [235]:
new_R = SE3.InitFromVec(p1)
R = new_R.matrix()

In [8]:
R1 = R.matrix()

In [14]:
R.shape

torch.Size([8, 30, 4, 4])

In [81]:
new_R = SE3.InitFromVec(p1)
R = new_R.matrix()
wbx_result = []
for r in R:
    bx = (r@old_mesh.T).permute(2, 0, 1)
    # W = W.unsqueeze(1)
    wbx = W1.unsqueeze(2) * bx

    wbx =  wbx.permute((1,0,2))
    wbx = wbx.sum(0).detach().cpu().numpy()
    wbx_result.append(wbx)

In [77]:
wbx_result0 = wbx_result

In [57]:
wbx = wbx_result.detach().cpu().numpy()[:, :-1]
wbx.shape

AttributeError: 'list' object has no attribute 'detach'

In [128]:
W3.unsqueeze(2) * old_mesh.unsqueeze(1).repeat(1, 45, 1)for wbx_index,single_wbx in enumerate(wbx_result):
    # single_wbx = single_wbx.detach().cpu().numpy()
    print("this is index ",wbx_index)
    print(np.sum((all_info[wbx_index]-single_wbx)**2))
    # for gt in all_info:
    #     print(np.sum((gt-single_wbx)**2))

this is index  0
21.36908918353724
this is index  1
128.30545293720138
this is index  2
373.0977968274271
this is index  3
771.560108692853
this is index  4
1293.8557314639206
this is index  5
1870.6729752832866
this is index  6
2412.4298995824365
this is index  7
2813.369257814287


In [263]:
W1 = W1.detach().cpu().numpy()

In [264]:
W1 = torch.tensor(W1)

In [63]:
color_list = []
number_list = [0,0.2,0.4,0.6,0.8,1]
for i in range(B):
    color = random.sample(number_list,3)
    color_list.append(color)

In [64]:
color_array = np.array(color_list)

In [7]:
color_array = np.array([[0.8, 0.4, 0.2],
       [0.6, 0.8, 1. ],
       [0.6, 0.2, 0.4],
       [1. , 0.6, 0.8],
       [0.2, 0.4, 0. ],
       [0.2, 0.4, 0.8],
       [0.4, 1. , 0.6],
       [1. , 0.4, 0. ],
       [0.6, 0.2, 0. ],
       [0.2, 1. , 0.4],
       [0.6, 0.2, 0.4],
       [0.6, 0.4, 0. ],
       [0.4, 0.6, 0.8],
       [0.8, 0. , 1. ],
       [0.4, 0.6, 0.2],
       [0.6, 1. , 0.2],
       [0. , 0.2, 0.4],
       [0.2, 0.8, 1. ],
       [0.4, 0.2, 1. ],
       [0.8, 0.4, 1. ],
       [0.8, 1. , 0. ],
       [1. , 0. , 0.6],
       [0.2, 1. , 0. ],
       [0.6, 0.2, 0.8],
       [1. , 0.6, 0.4],
       [0.4, 0. , 1. ],
       [0.2, 1. , 0.4],
       [0.2, 0.4, 0.8],
       [1. , 0. , 0.8],
       [1. , 0.8, 0.4]])

In [60]:
W_new = W1.clone()

In [28]:
W_new.sum()

tensor(3198.9993, device='cuda:0')

In [59]:
with torch.no_grad():
    for i in range(W1.shape[0]):
        max_idx = torch.topk(W1[i], 3)[1]
        for j in range(B):
            if j not in max_idx:
                W1[i, j] = 0
        W_new[i] = W1[i] / W1[i].sum()

NameError: name 'W_new' is not defined

In [61]:
W_c = W_new.detach().cpu().numpy()

In [13]:
(W_c@color_array).shape

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 30 is different from 40)

In [39]:
pcd_test = o3d.geometry.PointCloud()
pcd_test.points = o3d.utility.Vector3dVector(points_info)
# pcd_test.colors = o3d.utility.Vector3dVector(W_c@color_array)
pcd_test.normals = o3d.utility.Vector3dVector(points_normal)


In [40]:
distances = pcd_test.compute_nearest_neighbor_distance()
avg_dist = np.mean(distances)
radius = avg_dist 

In [41]:
bpa_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
pcd_test,o3d.utility.DoubleVector([radius, radius * 2]))

In [42]:
bpa_mesh.triangles = o3d.utility.Vector3iVector(points_face-1)

In [43]:
np.array(bpa_mesh.triangles).shape

(8804, 3)

In [44]:
o3d.visualization.draw_geometries([bpa_mesh])

In [89]:
len(wbx_result)

20

In [None]:
import open3d as o3d
pcd = o3d.geometry.PointCloud()
# pcd2 = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(wbx_result[-1][:, :-1])
# pcd.points = o3d.utility.Vector3dVector(points_info)
pcd.colors = o3d.utility.Vector3dVector(W_c@color_array)
# pcd.normals = o3d.utility.Vector3dVector(normals_info)
# pcd2.paint_uniform_color([0, 0, 1])
o3d.visualization.draw_geometries([pcd])


In [83]:
import open3d as o3d
pcd = o3d.geometry.PointCloud()
# pcd2 = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points_info)
# pcd2.points = o3d.utility.Vector3dVector(eigvector_array)
pcd.colors = o3d.utility.Vector3dVector(W_c@color_array)
# pcd2.paint_uniform_color([0, 0, 1])
o3d.visualization.draw_geometries([pcd])


RuntimeError: 

In [19]:
W3 = W1[:, -2]
# W2 = W3.detach().cpu().numpy()

In [21]:
onet = torch.zeros((8315, 2)).cuda()
# ff = torch.randn((8315,2)).cuda()
W2 = torch.cat([W3.unsqueeze(1), onet], dim=1)
W2.shape

torch.Size([8315, 3])

In [22]:
W2 = W2.detach().cpu().numpy()

In [212]:
W2.shape

(8315,)

In [145]:
import openmesh

In [223]:
mesh11 = openmesh.PolyMesh()
vv = []
for i in points_info:
    vv.append(mesh11.add_vertex(i))

In [224]:
for j in points_face:

    if len(j) ==3:
        mesh11.add_face([vv[j[0]-1],vv[j[1]-1],vv[j[2]-1]])
    elif len(j) ==4:
        mesh11.add_face([vv[j[0]-1],vv[j[1]-1],vv[j[2]-1],vv[j[3]-1]])


In [226]:
openmesh.write_mesh(mesh11, "bunny1111.obj")

TypeError: write_mesh(): incompatible function arguments. The following argument types are supported:
    1. (filename: str, mesh: openmesh.TriMesh, binary: bool = False, msb: bool = False, lsb: bool = False, swap: bool = False, vertex_normal: bool = False, vertex_color: bool = False, vertex_tex_coord: bool = False, halfedge_tex_coord: bool = False, edge_color: bool = False, face_normal: bool = False, face_color: bool = False, color_alpha: bool = False, color_float: bool = False) -> None
    2. (filename: str, mesh: openmesh.PolyMesh, binary: bool = False, msb: bool = False, lsb: bool = False, swap: bool = False, vertex_normal: bool = False, vertex_color: bool = False, vertex_tex_coord: bool = False, halfedge_tex_coord: bool = False, edge_color: bool = False, face_normal: bool = False, face_color: bool = False, color_alpha: bool = False, color_float: bool = False) -> None

Invoked with: <openmesh.PolyMesh object at 0x7f2207309f30>, 'bunny1111.obj'

In [165]:
points_face.max()

8315

In [56]:
R[0, 0, :-1, :-1]@(R[0, 0, :-1, :-1].T)

tensor([[ 1.0000e+00, -1.5586e-08,  1.6673e-08],
        [-1.5586e-08,  1.0000e+00, -2.3843e-08],
        [ 1.6673e-08, -2.3843e-08,  1.0000e+00]], device='cuda:0',
       grad_fn=<MmBackward>)

In [86]:
np.vstack([wbx_result[3][:, :-1], Bone_location[3]])

array([[ 0.60822886,  2.6521692 ,  0.7332895 ],
       [ 0.63643765,  2.6594446 ,  0.63895345],
       [ 0.66882116,  2.567572  ,  0.6456163 ],
       ...,
       [-0.85560316, -1.5839747 , -0.34642002],
       [-0.56195986,  1.6341548 ,  1.198669  ],
       [ 0.34078336, -0.55604595, -0.81593436]], dtype=float32)

In [76]:
Bone_location[3].shape

(45, 3)

In [78]:
wbx_result[3][:, :-1].shape

(8315, 3)

In [88]:
X = np.array([[1, 2], [1, 4], [1, 0],
...               [10, 2], [10, 4], [10, 0]])

In [89]:
X.shape

(6, 2)

In [118]:
center, indexes = kmeans_plusplus(old_mesh[:, :-1].detach().cpu().numpy(), n_clusters=10, random_state=0)

In [98]:
t

(45, 3)

In [104]:
np.save("R.npy", R.detach().cpu().numpy())

In [105]:
np.save("W.npy", W.detach().cpu().numpy())

In [9]:
W = np.load('W.npy')

In [10]:
R = np.load('R.npy')

In [24]:
R[0]

tensor([[[ 6.6551e-01,  5.5304e-01, -5.0125e-01,  7.7804e-01],
         [-4.2060e-01,  8.3265e-01,  3.6026e-01, -3.1423e-01],
         [ 6.1661e-01, -2.8933e-02,  7.8674e-01,  3.8223e-01],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],

        [[ 9.9372e-01, -1.0652e-01, -3.4407e-02,  1.7831e+00],
         [ 1.1189e-01,  9.5393e-01,  2.7840e-01,  2.1400e-01],
         [ 3.1677e-03, -2.8050e-01,  9.5985e-01, -5.5811e-01],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],

        [[-5.4936e-02,  3.2795e-01, -9.4310e-01,  5.8502e-01],
         [-4.7454e-01,  8.2246e-01,  3.1365e-01, -9.7908e-01],
         [ 8.7852e-01,  4.6477e-01,  1.1044e-01,  2.0641e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],

        [[ 5.6366e-01,  6.3519e-01, -5.2803e-01, -5.1763e-01],
         [-8.2434e-01,  4.7319e-01, -3.1073e-01,  6.6134e-01],
         [ 5.2489e-02,  6.1043e-01,  7.9033e-01, -1.2407e-01],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00, 

In [5]:
R_tmp = R[0]

In [7]:
eigvector_list = []
for i in range(R_tmp.shape[0]):
    (evals,evecs) = torch.eig(torch.tensor(R_tmp[i]),eigenvectors=True)
#     print(i)
#     print(evals)
#     print(evecs)
    flag =  0
    for val_index in range(len(evals)):
        
        if torch.abs(evals[val_index][0]-1)<1e-5 and torch.abs(evals[val_index][1])<1e-5:

            scale_info = -100
            if torch.abs(evecs[val_index][0])+torch.abs(evecs[val_index][1])+torch.abs(evecs[val_index][2])>1e-4:
#                 print(evecs[val_index])
#                 vector_wanted = (evecs[val_index]/torch.abs(evecs[val_index][-1]))[:-1]
                vector_wanted = evecs[val_index][:-1]
#                 flag = 1
#                 if scale_info != -100:
#                     vector_wanted = vector_wanted/torch.abs(scale_info)
#                     eigvector_list.append(vector_wanted.cpu().detach().numpy())
#                     break
#             else:
#                 scale_info = evecs[val_index][-1]
#                 if flag==1:
#                     vector_wanted = vector_wanted/torch.abs(scale_info)
#                     eigvector_list.append(vector_wanted.cpu().detach().numpy())
#                     break
#                 eigvector_list.append(vector_wanted.cpu().detach().numpy())
#                 break

torch.linalg.eig returns complex tensors of dtype cfloat or cdouble rather than real tensors mimicking complex tensors.
L, _ = torch.eig(A)
should be replaced with
L_complex = torch.linalg.eigvals(A)
and
L, V = torch.eig(A, eigenvectors=True)
should be replaced with
L_complex, V_complex = torch.linalg.eig(A) (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448224956/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:2897.)
  This is separate from the ipykernel package so we can avoid doing imports until


In [147]:
eigvector_list = []
for i in range(R_tmp.shape[0]):

    (evals,evecs) = torch.linalg.eig(R_tmp[i])
    for j in range(4):
        if torch.abs(evecs[j].real.max()) > 1e-5 and torch.abs(evecs[j].imag.max()) < 1e-5:
            eigvector_list.append(evecs[j])
    # print(evals, evecs)

In [8]:
eigvector_list

[]

In [133]:
R_tmp[i].shape

torch.Size([4, 4])

In [99]:
eigvector_array = np.array(eigvector_list)

In [104]:
eigvector_array

array([[-2.08931750e+05, -9.44651125e+05,  1.12544150e+06],
       [-3.24681225e+06, -1.43762031e+05,  8.98924188e+05],
       [ 1.95427598e+04, -5.86003062e+05,  3.64042188e+05],
       [-5.96894375e+05, -1.18165962e+06,  2.50963225e+06],
       [-1.24324430e+05,  4.42932750e+06,  4.41323500e+06],
       [ 1.31386288e+06, -3.10073656e+05, -5.37888375e+05],
       [-1.26899762e+06,  9.39687400e+06,  7.66483850e+06],
       [-9.65893750e+05,  0.00000000e+00,  3.28692266e+04],
       [-2.63832363e+04, -5.67883438e+05,  6.76968875e+05],
       [ 4.71411600e+06,  0.00000000e+00,  1.49484412e+06],
       [ 9.45399625e+05, -1.65568438e+06,  2.87590175e+06],
       [-5.24790875e+05,  1.72209975e+06,  1.08762620e+07],
       [-8.25391094e+04,  3.47957688e+05,  2.88241406e+05],
       [-1.15535425e+06,  7.71927500e+06,  5.02589250e+06],
       [-9.30507812e+05, -2.15668350e+06,  3.12389625e+06],
       [-2.89633725e+06, -4.90623450e+06,  6.42465350e+06],
       [-8.08298650e+06, -9.63290750e+05

In [103]:
for k in range(eigvector_array.shape[0]):
    
    print(eigvector_array[k].max(),eigvector_array[k].min())

1125441.5 -944651.1
898924.2 -3246812.2
364042.2 -586003.06
2509632.2 -1181659.6
4429327.5 -124324.43
1313862.9 -537888.4
9396874.0 -1268997.6
32869.227 -965893.75
676968.9 -567883.44
4714116.0 0.0
2875901.8 -1655684.4
10876262.0 -524790.9
347957.7 -82539.11
7719275.0 -1155354.2
3123896.2 -2156683.5
6424653.5 -4906234.5
-963290.75 -8082986.5
8315559.5 29267.422
266241.4 -2776276.0
7286642.5 0.0
4245924.0 -517617.72
317961.75 -934659.5
9409278.0 -515592.22
502299.16 -2235584.8
5599194.0 1259836.9
889088.3 -4156421.2
1063102.4 -114584.5
6152423.5 -237797.12
-401868.62 -1901944.4
-84895.055 -1103657.4
458805.28 -213043.7
2136201.5 578177.0
1281208.0 -15912.146
1737261.4 -8719609.0
8410769.0 -2069724.2
10213309.0 -1847051.1
78522.86 0.0
198893.73 -49068.426
1405885.1 -602139.7
4323145.0 -213344.0
9354421.0 -651071.4
6311035.0 -325663.12
212339.86 -1727798.6
7500047.5 -1767407.0
335404.25 -224781.38


In [29]:
(evals,evecs) = torch.eig(R[0][1],eigenvectors=True)

In [30]:
evals

tensor([[ 1.0000,  0.0000],
        [ 0.9537,  0.3006],
        [ 0.9537, -0.3006],
        [ 1.0000,  0.0000]], device='cuda:0')

In [31]:
evecs

tensor([[-9.2959e-01, -4.1160e-02,  2.5737e-01,  9.2959e-01],
        [-6.2496e-02,  7.0572e-01,  0.0000e+00,  6.2496e-02],
        [ 3.6326e-01,  1.6084e-02,  6.5861e-01, -3.6326e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  2.8631e-07]], device='cuda:0')

In [28]:
R[0][0]

tensor([[ 0.6655,  0.5530, -0.5012,  0.7780],
        [-0.4206,  0.8327,  0.3603, -0.3142],
        [ 0.6166, -0.0289,  0.7867,  0.3822],
        [ 0.0000,  0.0000,  0.0000,  1.0000]], device='cuda:0')

In [124]:
for j in range(45):
    R[0] 

torch.Size([8, 45, 4, 4])

In [None]:
def train(old_info,target_info, n_steps=1000, lr_init=0.2):
    """ Riemannian Gradient Descent """

    N = points_info.shape[0]
    F = target_info.shape[0]
    B = 45  ### 这个肯定可以通过读pkl得到


    old_mesh = torch.tensor(old_info).float().cuda().detach()
    
    # bp()
    # ones = torch.ones(old_mesh.shape[0], 1).cuda()
    # x = torch.cat([old_mesh, ones], dim=1)


    target_mesh = torch.tensor(target_info, requires_grad=False).float().cuda().detach()

    # bp()
    # W = torch.randint(10,(N,B))
    # W = torch.randn((N,B)).cuda()
    # W = np.random.randint(10,size=(N,B))
    # bp()
    # W = SO3(torch.from_numpy(W).float().cuda())
    # W = torch.zeros((N, B), dtype=torch.float, requires_grad=True, device="cuda")
    W = torch.randn((N, B), requires_grad=True, device="cuda")
    W2 = torch.randn((N, B), requires_grad=True, device="cuda")
    # W = torch.tensor(W0.float(), requires_grad=True, device="cuda")
    # W = W * 0
    # bp()
    # W = LieGroupParameter(W)
    # bp()
    # R = torch.randint(10,(B,3,3))laplacian = torch.zeros([self.nv, self.nv]).astype(np.float32)

    # bp()
    # random quaternion
    p1 = torch.randn((F, B, 7), requires_grad=True, device="cuda")
    # W = W.cuda()
    # p1 = p1.cuda()
    # p1 = p1 / p1.norm(dim=-1, keepdim=True)
    # create SO3 object from quaternion (differentiable w.r.t q)
    
    # 4x4 transformation matrix (differentiable w.r.t R)
    


    optimizer = optim.Adam([W, p1],lr=1e-3)
    chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
    for i in range(n_steps):

        optimizer.zero_grad()
        loss = 0
        f_idxs = random.sample(range(F), F)
        
        for f in f_idxs:
            T = torch.randn(B, 4, 4)
            for b in range(B):
                # R = SO3.InitFromVec(p1[f, b])
                R = SE3.InitFromVec(p1[f, b])
                # bp()
                T[b] = R.matrix()
            T = T.cuda()     
            # bp()      
            # bp() 
            # W1 = W
            # W1 = (W - W.min()) / (W.max() - W.min())
            W1 = FF.softmax(W*10)
            W1 = (W1 / (W1.sum(1, keepdim=True).detach()))
            W2[W1 < 0.1] = 0
            
            # ones = torch.ones(old_mesh.shape[0], 1).cuda()
            # x = torch.cat([old_mesh, ones], dim=1)
            x = old_mesh
        
        
            bx = (T@x.T).permute(2, 0, 1)
            wbx = W1.unsqueeze(2) * bx
            # bp()
            wbx =  wbx.permute((1,0,2))
            # print("Wbx", wbx.shape)
            # wbx = wbx.sum(0, keepdim=True)[:, :, :-1]
            wbx = wbx.sum(0, keepdim=True)
            # loss_init = chamLoss(target_mesh[f][None],wbx)
            # loss += loss_init[0].sum()

            loss += torch.sum((target_mesh[f]-wbx)**2) + torch.sum(torch.abs(W))

#             loss.backward()
            print("loss",loss)
            # bp()
        # bp()
        loss.backward()
        # print("this is loss for epoch {} : {} ".format(i,loss.detach().cpu().numpy()))
        print("this is loss for epoch {} : {} ".format(i,loss / F))
        # print("p1: ", p1)
        optimizer.step()


    new_R = SE3.InitFromVec(p1)
    R = new_R.matrix()

    # W1 = (W - W.min()) / (W.max() - W.min())
    # W1 = (W1 / (W1.sum(1, keepdim=True).detach()))
    # bp()
    # return R.detach(), W1.detach()
    return W.detach(),R.detach()