In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from plyfile import PlyData
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm 
from models.pointnet2_utils import PointNetSetAbstraction, PointNetFeaturePropagation  

In [None]:
class HumanPointCloudDataset(Dataset):
    """Dataset of fixed-size point clouds read from the ``.ply`` files in a directory.

    Each item is a ``float32`` tensor of shape ``[3, num_points]`` holding the
    x/y/z coordinates of the file's ``vertex`` element.

    Args:
        data_dir: Directory that is scanned (non-recursively) for ``*.ply`` files.
        num_points: Number of points every returned cloud is forced to have.
            Defaults to 1024.
    """

    def __init__(self, data_dir, num_points=1024):
        self.num_points = num_points
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.data_files = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.ply')
        )

    def __len__(self):
        """Return the number of ``.ply`` files found in ``data_dir``."""
        return len(self.data_files)

    @staticmethod
    def _fit_to_size(point_cloud, num_points):
        """Pad with zero rows or truncate so the cloud has exactly ``num_points`` rows.

        Clouds that are too short get all-zero points appended; clouds that are
        too long keep only their first ``num_points`` rows.
        """
        n_available = point_cloud.shape[0]
        if n_available < num_points:
            padding = np.zeros((num_points - n_available, 3), dtype=np.float32)
            return np.vstack([point_cloud, padding])
        return point_cloud[:num_points, :]

    def __getitem__(self, idx):
        """Load file ``idx`` and return its coordinates as a ``[3, num_points]`` tensor."""
        plydata = PlyData.read(self.data_files[idx])
        vertex = plydata['vertex']
        # Stack the coordinate columns into an [N, 3] array and convert to float32.
        point_cloud = np.c_[vertex['x'], vertex['y'], vertex['z']].astype(np.float32)
        point_cloud = self._fit_to_size(point_cloud, self.num_points)

        return torch.from_numpy(point_cloud).permute(1, 0)  # [3, N]

In [None]:
class PointNet2FeatureExtractor(nn.Module):
    """Encoder/decoder network that reconstructs the coordinates of its input cloud.

    Four set-abstraction stages (``sa1``..``sa4``) encode the cloud, four
    feature-propagation stages (``fp4``..``fp1``) decode back to the input
    points, and a 1x1 convolution maps the 128-channel per-point features to
    3 output channels.

    NOTE(review): the positional arguments of ``PointNetSetAbstraction`` are
    presumably ``(npoint, radius, nsample, in_channel, mlp, group_all)`` and the
    ``in_channel`` values look like 3 (xyz) + the previous stage's feature
    width (6, 67, 131, 259) -- confirm against ``models.pointnet2_utils``.
    """

    def __init__(self):
        super().__init__()
        # Encoder (Feature Extraction)
        self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 6, [32, 32, 64], False)
        self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 67, [64, 64, 128], False)
        self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 131, [128, 128, 256], False)
        self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 259, [256, 256, 512], False)

        # Decoder (For Reconstruction)
        # Input widths are the concatenation of the skip features and the
        # coarser-level features: 768 = 256 + 512, 384 = 128 + 256, 320 = 64 + 256.
        self.fp4 = PointNetFeaturePropagation(768, [256, 256])
        self.fp3 = PointNetFeaturePropagation(384, [256, 256])
        self.fp2 = PointNetFeaturePropagation(320, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128])

        self.final_conv = nn.Conv1d(128, 3, 1)

    def forward(self, xyz, return_features=False):
        """Reconstruct the input point cloud.

        Args:
            xyz: Input tensor ``[B, 3, N]``; the coordinates are used both as
                positions and as the initial per-point features.
            return_features: When ``True``, also return the 128-channel
                per-point features that feed ``final_conv``. Defaults to
                ``False``, which returns only the reconstruction.

        Returns:
            The reconstruction ``[B, 3, N]``, or the tuple
            ``(reconstruction, point_features)`` with ``point_features`` of
            shape ``[B, 128, N]`` when ``return_features`` is ``True``.
        """
        l0_xyz = xyz[:, :3, :]  # [B, 3, N]
        l0_points = xyz  # [B, 3, N]

        # Feature Extraction
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)

        # Feature Propagation (Decoding)
        l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        # No skip features are passed at the finest level (None).
        point_features = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        reconstruction = self.final_conv(point_features)  # [B, 3, N]

        if return_features:
            return reconstruction, point_features
        return reconstruction  # [B, 3, N]

In [None]:

def train_model(model, dataloader, epochs, lr, device,
                save_path="pointnet2_feature_extractor.pth"):
    """Train ``model`` to reconstruct its input and save the resulting weights.

    Args:
        model: Module mapping a ``[B, 3, N]`` batch to a tensor of the same shape.
        dataloader: Sized iterable yielding ``[B, 3, N]`` batches.
        epochs: Number of passes over ``dataloader``.
        lr: Learning rate for the Adam optimizer.
        device: Device the model and batches are moved to.
        save_path: File the final ``state_dict`` is written to. Defaults to
            ``"pointnet2_feature_extractor.pth"``.

    Raises:
        ValueError: If ``dataloader`` contains no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail early with a clear message instead of dividing by zero when
        # the per-epoch average loss is computed.
        raise ValueError("dataloader is empty; nothing to train on")

    # Move the model first so the optimizer is constructed over the
    # parameters as they live on the target device.
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        with tqdm(total=num_batches, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch") as pbar:
            for point_clouds in dataloader:
                point_clouds = point_clouds.to(device)  # [B, 3, N]
                optimizer.zero_grad()

                reconstructed = model(point_clouds)  # [B, 3, N]
                loss = F.mse_loss(reconstructed, point_clouds)  # Simple reconstruction loss
                loss.backward()
                optimizer.step()

                # Read the scalar once; each .item() call copies from the device.
                batch_loss = loss.item()
                epoch_loss += batch_loss
                pbar.set_postfix(loss=batch_loss)
                pbar.update(1)

        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1} complete. Average Loss: {avg_loss}")

    print("Training Complete")
    torch.save(model.state_dict(), save_path)

# Directory holding the training .ply files. Set the POINTCLOUD_DATA_DIR
# environment variable to run on another machine without editing this cell;
# the original path is kept as the fallback.
data_dir = os.environ.get("POINTCLOUD_DATA_DIR", "/home/jerry/Pointnet_Pointnet2_pytorch/data")
dataset = HumanPointCloudDataset(data_dir)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PointNet2FeatureExtractor()
train_model(model, dataloader, epochs=50, lr=1e-3, device=device)

In [None]:
def extract_spatial_features(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and average its output over the point axis.

    Args:
        model: Module mapping a ``[B, 3, N]`` batch to ``[B, feature_dim, N]``.
        dataloader: Iterable yielding ``[B, 3, N]`` batches. Rows of the result
            follow the order in which batches are yielded, so pass a
            non-shuffled loader if rows must be matched back to samples.
        device: Device the model and batches are moved to for inference.

    Returns:
        CPU tensor of shape ``[total_samples, feature_dim]``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.

    NOTE(review): with a model whose output is the 3-channel reconstruction
    (as ``PointNet2FeatureExtractor.forward`` in this notebook), ``feature_dim``
    is 3 and each row is just the mean of the reconstructed coordinates --
    confirm that this is the intended "feature".
    """
    model = model.to(device)
    model.eval()
    features = []

    with torch.no_grad():
        for point_clouds in tqdm(dataloader, desc="Extracting Spatial Features"):
            point_clouds = point_clouds.to(device)  # [B, 3, N]
            spatial_features = model(point_clouds)  # [B, feature_dim, N]
            # Pool over points, then move to the CPU so results for the whole
            # dataset do not accumulate in device memory.
            features.append(spatial_features.mean(dim=2).cpu())  # [B, feature_dim]

    if not features:
        raise ValueError("dataloader yielded no batches; no features to extract")

    return torch.cat(features, dim=0)  # [total_samples, feature_dim]

# The training `dataloader` shuffles, which would save the feature rows in a
# random order. Iterate the dataset in order instead, so that row i of the
# saved tensor corresponds to dataset.data_files[i].
eval_dataloader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4)

spatial_model = PointNet2FeatureExtractor()
# "your model.pth" is a placeholder: point it at trained weights (train_model
# in this notebook writes "pointnet2_feature_extractor.pth").
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
spatial_model.load_state_dict(torch.load("your model.pth", map_location=device))

spatial_features = extract_spatial_features(spatial_model, eval_dataloader, device)
# "your.pt" is a placeholder output path.
torch.save(spatial_features, "your.pt")