In [10]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
import torch.nn as nn
import torch.optim as optim
from network.gap_layers import *
from datasets.datasets_pair import *
import functools
from network.sym_v1 import *
from loss.utils import *
from network.utils import *
from network.PointTransformerv3 import *
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datetime import datetime
import spconv.pytorch as spconv

In [16]:
class test_PT3(nn.Module):
    def __init__(self):
        super().__init__()
        self.pt3 = PointTransformerV3()
        self.rot_green_head = nn.Linear(64, 3)
        self.rot_red_head = nn.Linear(64, 3)
        

    def forward(self, pc_pairs: List[PointCloudPair]):
        pc1s = [pc_pair.pc1 for pc_pair in pc_pairs]
        pc2s = [pc_pair.pc2 for pc_pair in pc_pairs]
        bs = len(pc_pairs)
        pc_batch_1: PointCloudBatch = PointCloud.collate(pc1s)
        pc_batch_2: PointCloudBatch = PointCloud.collate(pc2s)
        
        pc_voxel_id_1 = pc_batch_1.pc_voxel_id
        pc_voxel_id_2 = pc_batch_2.pc_voxel_id
        in_dict_1 = {
            "feat": pc_batch_1.voxel_tensor.features,
            "grid_coord": pc_batch_1.voxel_tensor.indices[:, 1:],
            "batch": pc_batch_1.voxel_tensor.indices[:, 0],
            "coord": pc_batch_1.voxel_tensor.indices[:, 1:] * 0.01,
            "grid_size": 0.01
        }
        in_dict_2 = {
            "feat": pc_batch_2.voxel_tensor.features,
            "grid_coord": pc_batch_2.voxel_tensor.indices[:, 1:],
            "batch": pc_batch_2.voxel_tensor.indices[:, 0],
            "coord": pc_batch_2.voxel_tensor.indices[:, 1:] * 0.01,
            "grid_size": 0.01
        }
        outs_1 = self.pt3(in_dict_1)
        outs_2 = self.pt3(in_dict_2)
        pc_feature_1 = outs_1['feat'][pc_voxel_id_1]
        pc_feature_2 = outs_2['feat'][pc_voxel_id_2]

        pc_feature_1 = pc_feature_1.view(bs, -1, 64) # bs,n,64
        pc_feature_2 = pc_feature_2.view(bs, -1, 64)

        rot_green_1 = self.rot_green_head(pc_feature_1).mean(dim=1).view(bs, 3) # bs,n,3
        rot_green_2 = self.rot_green_head(pc_feature_2).mean(dim=1).view(bs, 3)

        rot_red_1 = self.rot_red_head(pc_feature_1).mean(dim=1).view(bs, 3)
        rot_red_2 = self.rot_red_head(pc_feature_2).mean(dim=1).view(bs, 3) # bs,3
        
        return (rot_green_1, rot_green_2), (rot_red_1, rot_red_2)

In [12]:
# loss
class fs_net_loss_R(nn.Module):
    def __init__(self, loss_type="smoothl1"):
        super(fs_net_loss_R, self).__init__()
        if loss_type == 'l1':
            self.loss_func_t = nn.L1Loss()
            self.loss_func_s = nn.L1Loss()
            self.loss_func_Rot1 = nn.L1Loss()
            self.loss_func_Rot2 = nn.L1Loss()
            self.loss_func_r_con = nn.L1Loss()
            self.loss_func_Recon = nn.L1Loss()
        elif loss_type == 'smoothl1':   # same as MSE
            self.loss_func_t = nn.SmoothL1Loss(beta=0.5)
            self.loss_func_s = nn.SmoothL1Loss(beta=0.5)
            self.loss_func_Rot1 = nn.SmoothL1Loss(beta=0.5)
            self.loss_func_Rot2 = nn.SmoothL1Loss(beta=0.5)
            self.loss_func_r_con = nn.SmoothL1Loss(beta=0.5)
            self.loss_func_Recon = nn.SmoothL1Loss(beta=0.3)
        else:
            raise NotImplementedError

    def forward(self, pred_list, gt_list, sym):
        loss_list = {}

        self.rot_1_w = 1

        loss_list["Rot1"] = self.rot_1_w * self.cal_loss_Rot1(pred_list["Rot1"], gt_list["Rot1"])

        loss_list["Rot2"] = self.rot_1_w * self.cal_loss_Rot2(pred_list["Rot2"], gt_list["Rot2"], sym)

        # loss_list["Recon"] = self.recon_w * self.cal_loss_Recon(pred_list["Recon"], gt_list["Recon"])

        # loss_list["Tran"] = self.tran_w * self.cal_loss_Tran(pred_list["Tran"], gt_list["Tran"])
    
        # loss_list["Size"] = self.size_w * self.cal_loss_Size(pred_list["Size"], gt_list["Size"])

        return loss_list

    def cal_loss_Rot1(self, pred_v, gt_v):
        bs = pred_v.shape[0]
        res = torch.zeros([bs], dtype=torch.float32, device=pred_v.device)
        for i in range(bs):
            pred_v_now = pred_v[i, ...]
            gt_v_now = gt_v[i, ...]
            res[i] = self.loss_func_Rot1(pred_v_now, gt_v_now)
        res = torch.mean(res)
        return res

    def cal_loss_Rot2(self, pred_v, gt_v, sym):
        bs = pred_v.shape[0]
        res = 0.0
        valid = 0.0
        for i in range(bs):
            sym_now = sym[i, 0]
            if sym_now == 1:
                continue
            else:
                pred_v_now = pred_v[i, ...]
                gt_v_now = gt_v[i, ...]
                res += self.loss_func_Rot2(pred_v_now, gt_v_now)
                valid += 1.0
        if valid > 0.0:
            res = res / valid
        return res

    def cal_loss_Recon(self, pred_recon, gt_recon):
        return self.loss_func_Recon(pred_recon, gt_recon)

    def cal_loss_Tran(self, pred_trans, gt_trans):
        return self.loss_func_t(pred_trans, gt_trans)

    def cal_loss_Size(self, pred_size, gt_size):
        return self.loss_func_s(pred_size, gt_size)

In [13]:
root_dir = "/16T/zhangran/GAPartNet_re_rendered/train"
test_intra_dir = "/16T/zhangran/GAPartNet_re_rendered/test_intra"
test_inter_dir = "/16T/zhangran/GAPartNet_re_rendered/test_inter"
def get_datasets(root_dir, test_intra_dir, test_inter_dir, voxelization=False, shot=False):
    if shot:
        few_shot = True
        few_shot_num = 20
    else:
        few_shot = False
        few_shot_num = None

    dataset_train = GAPartNetPair(
        Path(root_dir) / "pth",
        Path(root_dir) / "meta",
        shuffle=True,
        max_points=2000,
        augmentation=True,
        voxelization=voxelization, 
        group_size=2,
        voxel_size=[0.01,0.01,0.01],
        few_shot=few_shot,
        few_shot_num=few_shot_num,
        pos_jitter=0.1,
        with_pose=True,
        color_jitter=0.3,
        flip_prob=0.3,
        rotate_prob=0.3,
    )

    dataset_test_intra = GAPartNetPair(
        Path(test_intra_dir) / "pth",
        Path(test_intra_dir) / "meta",
        shuffle=False,
        max_points=2000,
        augmentation=True,
        voxelization=voxelization, 
        group_size=2,
        voxel_size=[0.01,0.01,0.01],
        few_shot=few_shot,
        few_shot_num=few_shot_num,
        pos_jitter=0.1,
        with_pose=True,
        color_jitter=0.3,
        flip_prob=0.3,
        rotate_prob=0.3,
    )

    dataset_test_inter = GAPartNetPair(
        Path(test_inter_dir) / "pth",
        Path(test_inter_dir) / "meta",
        shuffle=False,
        max_points=2000,
        augmentation=True,
        voxelization=voxelization, 
        group_size=2,
        voxel_size=[0.01,0.01,0.01],
        few_shot=few_shot,
        few_shot_num=few_shot_num,
        pos_jitter=0.1,
        with_pose=True,
        color_jitter=0.3,
        flip_prob=0.3,
        rotate_prob=0.3,
    )

    return dataset_train, dataset_test_intra, dataset_test_inter

def get_dataloaders(dataset_train, dataset_test_intra, dataset_test_inter, batch_size=16, num_workers=8):
    dataloader_train = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=data_utils.trivial_batch_collator,
        pin_memory=True,
        drop_last=False
    )
    # test_intra_sampler = DistributedSampler(dataset_test_intra, shuffle=False)
    dataloader_test_intra = DataLoader(
        dataset_test_intra,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=data_utils.trivial_batch_collator,
        pin_memory=True,
        drop_last=False,
        # sampler=test_intra_sampler
    )
    # test_inter_sampler = DistributedSampler(dataset_test_inter, shuffle=False)
    dataloader_test_inter = DataLoader(
        dataset_test_inter,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=data_utils.trivial_batch_collator,
        pin_memory=True,
        drop_last=False,
        # sampler=test_inter_sampler
    )
    return dataloader_train, dataloader_test_intra, dataloader_test_inter

In [14]:
# Helper function to extract ground truth rotation vectors from the batch of PointCloudPairs
def ground_truth_rotations(rot_list: List[torch.Tensor]) -> np.ndarray:
    rotations = []
    for rot in rot_list:
        # Assuming the rotations are stored as 3x3 matrices in pc_pair.rot_1 and pc_pair.rot_2
        rotation_matrix = np.array(rot.cpu())  # Example using rot_1, adjust as needed
        rotations.append(rotation_matrix)
    return torch.tensor(np.stack(rotations))

def train(model: nn.Module, 
          dataloader_train: DataLoader, 
          dataloader_test_inter: DataLoader, 
          dataloader_test_intra: DataLoader, 
          lr: int = 0.001, 
          num_epochs: int=100, 
          log_dir: str=None, 
          device: torch.device=None):
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = fs_net_loss_R()
    if not device:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.train()
    assert log_dir is not None, "No Log Dir"
    log_dir = log_dir + "/" + str(datetime.today())
    writer = SummaryWriter(log_dir=log_dir)

    global_step = 0
    print("_________________________train_epoch___________________________")
    for epoch in range(num_epochs):
        total_loss = 0
        if epoch == 0:
            # first test epoch
            print("______________________first_test_epoch_________________________")
            torch.save(model.state_dict(), log_dir+r'/'+f"GPV_[{epoch+1}|{num_epochs}].pth")
            test_metrics(model, dataloader_test_inter, device, writer, epoch, 'test_inter')
            test_metrics(model, dataloader_test_intra, device, writer, epoch, 'test_intra')
        for batch_idx, batch in enumerate(dataloader_train):
            pc_pairs = [pair.to(device) for pair in batch]
            optimizer.zero_grad()

            (p_green_R1, p_red_R1), (p_green_R2, p_red_R2) = model(pc_pairs)
            
            # Assuming we have ground truth rotations
            R_green_gt1, R_red_gt1 = get_gt_v(ground_truth_rotations([pc.rot_1 for pc in pc_pairs]))  # Function to get ground truth rotation vectors
            R_green_gt2, R_red_gt2 = get_gt_v(ground_truth_rotations([pc.rot_2 for pc in pc_pairs]))  # Function to get ground truth rotation vectors
            
            pred_list1 = {
                "Rot1": p_green_R1,
                "Rot2": p_red_R1,
            }
            gt_list1 = {
                "Rot1": R_green_gt1.cuda(),
                "Rot2": R_red_gt1.cuda(),
            }
            
            pred_list2 = {
                "Rot1": p_green_R2,
                "Rot2": p_red_R2,
            }
            gt_list2 = {
                "Rot1": R_green_gt2.cuda(),
                "Rot2": R_red_gt2.cuda(),
            }

            sym1, sym2 = get_sym_from_input(pc_pairs)

            loss_dict1 = criterion(pred_list1, gt_list1, sym1)
            loss_dict2 = criterion(pred_list2, gt_list2, sym2)
            loss = (loss_dict1['Rot1'] + loss_dict1['Rot2'] + loss_dict2['Rot1'] + loss_dict2['Rot2']) / 2.0
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            global_step += 1

            # 每10个batch记录一次loss
            if (batch_idx + 1) % 10 == 0:
                writer.add_scalar('train/loss', loss.item(), global_step)
                print(f"Epoch:[{epoch + 1}|{num_epochs}],Batch:[{(batch_idx + 1)}|{len(dataloader_train)}],Loss:[{loss.item():.4f}]")

        avg_loss = total_loss / len(dataloader_train)
        print(f"Epoch [{epoch+1}|{num_epochs}],Loss:{avg_loss:.4f}")
        writer.add_scalar('train/avg_loss', avg_loss, epoch)

        # 每5个epoch跑一次测试集
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), log_dir+r'/'+f"GPV_[{epoch+1}|{num_epochs}].pth")
            test_metrics(model, dataloader_test_inter, device, writer, epoch, 'test_inter')
            test_metrics(model, dataloader_test_intra, device, writer, epoch, 'test_intra')


def test_metrics(model, dataloader, device, writer, epoch, phase):
    print("______________________" + phase + "_______________________")
    model.eval()
    all_pred_rot_matrices = []
    all_gt_rot_matrices = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            pc_pairs = [pair.to(device) for pair in batch]
            (p_green_R1, p_red_R1), (p_green_R2, p_red_R2) = model(pc_pairs)
            
            # Assuming we have ground truth rotations
            R_green_gt1, R_red_gt1 = get_gt_v(ground_truth_rotations([pc.rot_1 for pc in pc_pairs]))  # Function to get ground truth rotation vectors
            R_green_gt2, R_red_gt2 = get_gt_v(ground_truth_rotations([pc.rot_2 for pc in pc_pairs]))  # Function to get ground truth rotation vectors
            
            # Convert predicted vectors and ground truth vectors back to rotation matrices
            pred_rot_matrices1 = vectors_to_rotation_matrix(p_green_R1, p_red_R1)
            pred_rot_matrices2 = vectors_to_rotation_matrix(p_green_R2, p_red_R2)
            gt_rot_matrices1 = vectors_to_rotation_matrix(R_green_gt1, R_red_gt1)
            gt_rot_matrices2 = vectors_to_rotation_matrix(R_green_gt2, R_red_gt2)
            
            # Store predictions and ground truths for metrics calculation
            all_pred_rot_matrices.append(pred_rot_matrices1.cpu())
            all_pred_rot_matrices.append(pred_rot_matrices2.cpu())
            all_gt_rot_matrices.append(gt_rot_matrices1.cpu())
            all_gt_rot_matrices.append(gt_rot_matrices2.cpu())
    
    all_pred_rot_matrices = torch.cat(all_pred_rot_matrices, dim=0)
    all_gt_rot_matrices = torch.cat(all_gt_rot_matrices, dim=0)

    mean_rot_error = calculate_pose_metrics(
        all_pred_rot_matrices, all_gt_rot_matrices
    )
    writer.add_scalar(f'{phase}/mean_rot_error', mean_rot_error, epoch)
    print(f"{phase} - Epoch [{epoch+1}]: Mean Rotation Error: {mean_rot_error:.4f}")
    model.train()

In [17]:
model = test_PT3()
dataset_train, dataset_test_intra, dataset_test_inter = get_datasets(root_dir, test_intra_dir, test_inter_dir, voxelization=True, shot=True)
dataloader_train, dataloader_test_intra, dataloader_test_inter = get_dataloaders(dataset_train, dataset_test_intra, dataset_test_inter)
train(model, dataloader_train, dataloader_test_inter, dataloader_test_intra, 0.001, 40, "./log_dir/PointTrnsformer_test_sym_v1")

_________________________train_epoch___________________________
______________________first_test_epoch_________________________
______________________test_inter_______________________


100%|██████████| 2/2 [00:01<00:00,  1.08it/s]


test_inter - Epoch [1]: Mean Rotation Error: 92.3998
______________________test_intra_______________________


100%|██████████| 2/2 [00:01<00:00,  1.30it/s]

test_intra - Epoch [1]: Mean Rotation Error: 108.7847





Epoch [1|40],Loss:1.2760
Epoch [2|40],Loss:0.8810
Epoch [3|40],Loss:0.5089
Epoch [4|40],Loss:0.5419
Epoch [5|40],Loss:0.4973
______________________test_inter_______________________


100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


test_inter - Epoch [5]: Mean Rotation Error: 78.7313
______________________test_intra_______________________


100%|██████████| 2/2 [00:01<00:00,  1.48it/s]

test_intra - Epoch [5]: Mean Rotation Error: 77.6891





Epoch [6|40],Loss:0.3958
Epoch [7|40],Loss:0.4063
Epoch [8|40],Loss:0.4237
Epoch [9|40],Loss:0.3947
Epoch [10|40],Loss:0.3484
______________________test_inter_______________________


100%|██████████| 2/2 [00:01<00:00,  1.53it/s]


test_inter - Epoch [10]: Mean Rotation Error: 83.9853
______________________test_intra_______________________


100%|██████████| 2/2 [00:01<00:00,  1.49it/s]

test_intra - Epoch [10]: Mean Rotation Error: 92.6282





Epoch [11|40],Loss:0.3828
Epoch [12|40],Loss:0.3744
Epoch [13|40],Loss:0.3649
Epoch [14|40],Loss:0.3721
Epoch [15|40],Loss:0.3681
______________________test_inter_______________________


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


test_inter - Epoch [15]: Mean Rotation Error: 87.6344
______________________test_intra_______________________


100%|██████████| 2/2 [00:01<00:00,  1.52it/s]

test_intra - Epoch [15]: Mean Rotation Error: 90.3215





Epoch [16|40],Loss:0.3425
Epoch [17|40],Loss:0.3592
Epoch [18|40],Loss:0.3289
Epoch [19|40],Loss:0.3412
Epoch [20|40],Loss:0.3220
______________________test_inter_______________________


100%|██████████| 2/2 [00:01<00:00,  1.50it/s]


test_inter - Epoch [20]: Mean Rotation Error: 81.4152
______________________test_intra_______________________


100%|██████████| 2/2 [00:01<00:00,  1.55it/s]

test_intra - Epoch [20]: Mean Rotation Error: 77.8457





Epoch [21|40],Loss:0.3293
Epoch [22|40],Loss:0.3207
Epoch [23|40],Loss:0.3424
Epoch [24|40],Loss:0.3217
Epoch [25|40],Loss:0.3478
______________________test_inter_______________________


100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


test_inter - Epoch [25]: Mean Rotation Error: 76.5967
______________________test_intra_______________________


100%|██████████| 2/2 [00:01<00:00,  1.02it/s]

test_intra - Epoch [25]: Mean Rotation Error: 77.4199





Epoch [26|40],Loss:0.3191
Epoch [27|40],Loss:0.3171
Epoch [28|40],Loss:0.2993
Epoch [29|40],Loss:0.3292
Epoch [30|40],Loss:0.3127
______________________test_inter_______________________


100%|██████████| 2/2 [00:01<00:00,  1.34it/s]


test_inter - Epoch [30]: Mean Rotation Error: 76.9464
______________________test_intra_______________________


100%|██████████| 2/2 [00:01<00:00,  1.36it/s]

test_intra - Epoch [30]: Mean Rotation Error: 76.9780





Epoch [31|40],Loss:0.3679
Epoch [32|40],Loss:0.3690
Epoch [33|40],Loss:0.3908
Epoch [34|40],Loss:0.3692
Epoch [35|40],Loss:0.3414
______________________test_inter_______________________


100%|██████████| 2/2 [00:01<00:00,  1.26it/s]


test_inter - Epoch [35]: Mean Rotation Error: 78.3154
______________________test_intra_______________________


100%|██████████| 2/2 [00:01<00:00,  1.36it/s]

test_intra - Epoch [35]: Mean Rotation Error: 76.3035





Epoch [36|40],Loss:0.3542
Epoch [37|40],Loss:0.3304
Epoch [38|40],Loss:0.2493
Epoch [39|40],Loss:0.3107
Epoch [40|40],Loss:0.2589
______________________test_inter_______________________


100%|██████████| 2/2 [00:01<00:00,  1.34it/s]


test_inter - Epoch [40]: Mean Rotation Error: 68.2631
______________________test_intra_______________________


100%|██████████| 2/2 [00:01<00:00,  1.34it/s]

test_intra - Epoch [40]: Mean Rotation Error: 107.8630





In [5]:
model = PointTransformerV3().cuda()
dataset_train, dataset_test_intra, dataset_test_inter = get_datasets(root_dir, test_intra_dir, test_inter_dir, voxelization=True, shot=True)
dataloader_train, dataloader_test_intra, dataloader_test_inter = get_dataloaders(dataset_train, dataset_test_intra, dataset_test_inter)
pairs = next(iter(dataloader_train))
pc_pairs = [pair.to("cuda") for pair in pairs]
pc1s = [pc_pair.pc1 for pc_pair in pc_pairs]
pc2s = [pc_pair.pc2 for pc_pair in pc_pairs]
bs = len(pc_pairs)
pc_batch_1: PointCloudBatch = PointCloud.collate(pc1s)
# pc_batch_2: PointCloudBatch = PointCloud.collate(pc2s)
# pc_batch_1.voxel_tensor.features
in_dict_1 = {
    "feat": pc_batch_1.voxel_tensor.features,
    "grid_coord": pc_batch_1.voxel_tensor.indices[:, 1:],
    "batch": pc_batch_1.voxel_tensor.indices[:, 0],
    "coord": pc_batch_1.voxel_tensor.indices[:, 1:] * 0.01,
    "grid_size": 0.01
    # "sparse_shape": pc_batch_1.voxel_tensor.spatial_shape,
    # "sparse_conv_feat": pc_batch_1.voxel_tensor,
}
outs = model(in_dict_1)
# # print(model)
# features = st.features
# indices = st.indices
# spatial_shape = st.spatial_shape
# batch_size = st.batch_size

# # Extract batch and grid coordinates from indices
# batch = indices[:, 0]
# grid_coord = indices[:, 1:]

# # Construct the dictionary for initializing `Point`
# input_dict = {
#     "feat": features,
#     "grid_coord": grid_coord,
#     "batch": batch,
#     "sparse_shape": spatial_shape,
#     "sparse_conv_feat": st
# }


In [9]:
print(outs['feat'][pc_batch_1.pc_voxel_id].shape)

torch.Size([32000, 64])


In [None]:
ts = spconv.SparseConvTensor()
ts.features
