In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from plyfile import PlyData
from tqdm import tqdm
from models.pointnet2_utils import PointNetSetAbstraction, PointNetFeaturePropagation

In [None]:
class HumanPointCloudDataset(Dataset):
    """Fixed-length sequences of point clouds stored as one ``.ply`` file per frame.

    Each immediate sub-directory of ``data_dir`` is one candidate sequence; its
    ``.ply`` files, sorted by filename, are the frames. Sub-directories whose
    number of ``.ply`` files differs from ``num_frames`` are skipped.

    Args:
        data_dir: Root directory containing one sub-directory per sequence.
        num_points: Every frame is zero-padded or truncated to exactly this many points.
        num_frames: Required number of ``.ply`` frames per sequence (default 16).
    """

    def __init__(self, data_dir, num_points=1024, num_frames=16):
        self.sequences = []
        self.num_points = num_points
        self.num_frames = num_frames

        for folder in sorted(os.listdir(data_dir)):
            folder_path = os.path.join(data_dir, folder)
            if not os.path.isdir(folder_path):
                continue
            files = sorted(
                os.path.join(folder_path, f)
                for f in os.listdir(folder_path)
                if f.endswith('.ply')
            )
            # Only complete sequences are kept so every item stacks to the same [T, N, 3] shape.
            if len(files) == num_frames:
                self.sequences.append(files)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return sequence ``idx`` as a float32 tensor of shape [num_frames, num_points, 3]."""
        point_clouds = [self.load_ply(path) for path in self.sequences[idx]]
        return torch.stack(point_clouds, dim=0)  # Shape: [T, N, 3]

    def load_ply(self, file_path):
        """Read the x/y/z vertex properties of a PLY file as a [num_points, 3] float32 tensor."""
        plydata = PlyData.read(file_path)
        vertex = plydata['vertex']
        points = np.c_[vertex['x'], vertex['y'], vertex['z']].astype(np.float32)
        return torch.as_tensor(self._fit_num_points(points), dtype=torch.float32)

    def _fit_num_points(self, points):
        """Zero-pad or truncate an [M, C] array to exactly ``self.num_points`` rows."""
        count = points.shape[0]
        if count < self.num_points:
            # NOTE(review): zero padding inserts duplicate points at the origin, which the
            # encoder and any Chamfer loss treat as real geometry; repeating existing
            # points is a common alternative.
            padding = np.zeros((self.num_points - count, points.shape[1]), dtype=points.dtype)
            return np.vstack([points, padding])
        # Truncation keeps the first num_points vertices in file order (no resampling).
        return points[:self.num_points]


In [None]:
class PointNet2FeatureExtractor(nn.Module):
    """PointNet++ encoder that summarizes one point cloud as a single global feature vector.

    ``forward`` applies four set-abstraction levels and averages the last level's
    features over dim 2. ``decode`` runs the ``fp1`` feature-propagation head on
    such a global vector to produce a 3-channel output per query point.
    """

    def __init__(self):
        super().__init__()
        # Positional args as used here: (npoint, radius, nsample, in_channel, mlp).
        # Presumably in_channel = previous level's last MLP width + 3 because relative xyz is
        # concatenated to grouped features (3 when no features exist yet) — confirm in
        # models/pointnet2_utils.
        self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 3, [32, 32, 64], group_all=False)
        self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 67, [64, 64, 128], group_all=False)
        self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 131, [128, 128, 256], group_all=False)
        self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 259, [256, 256, 512], group_all=False)
        # 512 input channels match the width of the global feature returned by forward().
        self.fp1 = PointNetFeaturePropagation(512, [256, 128, 3])

    def forward(self, xyz):
        """Encode a batch of point clouds.

        Args:
            xyz: [B, N, 3] coordinates; permuted to channel-first [B, 3, N] for the SA layers.

        Returns:
            [B, 512] global features (sa4 output averaged over dim 2).
        """
        l0_xyz = xyz.permute(0, 2, 1)  # [B, 3, N]
        l1_xyz, l1_points = self.sa1(l0_xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        _, l4_points = self.sa4(l3_xyz, l3_points)
        # Global average pooling across the remaining centroids -> [B, 512]
        return l4_points.mean(dim=2)

    def decode(self, features, num_points):
        """Propagate a global feature to ``num_points`` query points through ``fp1``.

        Args:
            features: [B, 512] global features, e.g. as returned by ``forward``.
            num_points: Number of output points per sample.

        Returns:
            The ``fp1`` output, presumably [B, 3, num_points] — confirm against
            PointNetFeaturePropagation.
        """
        batch_size = features.size(0)
        factory = {"device": features.device, "dtype": features.dtype}
        # xyz1: the query positions that receive the propagated feature.
        query_xyz = torch.zeros(batch_size, 3, num_points, **factory)
        # xyz2 / points2: a single source point carrying the global feature. Feature propagation
        # interpolates *from* these; the previous call passed None for both, leaving nothing to
        # propagate from.
        anchor_xyz = torch.zeros(batch_size, 3, 1, **factory)
        # NOTE(review): with one source point and points1=None, every query point presumably
        # receives the same feature, so all decoded points coincide. Also, no loss in this notebook
        # optimizes fp1. A real point decoder (e.g. folding with per-point coordinates, trained
        # with a reconstruction loss) is needed for meaningful reconstructions.
        return self.fp1(query_xyz, anchor_xyz, None, features.unsqueeze(-1))

In [None]:
class SequenceModel(nn.Module):
    """Batch-first LSTM followed by a per-step linear projection back to ``input_dim``.

    Maps a [B, T, input_dim] sequence to a [B, T, input_dim] sequence.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        # The attribute names `lstm` and `fc` become the state_dict keys stored in checkpoints.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        hidden_seq = self.lstm(x)[0]  # top-layer hidden state at every step: [B, T, hidden_dim]
        return self.fc(hidden_seq)

# Chamfer Distance
def chamfer_distance(pc1, pc2):
    """Symmetric Chamfer distance between two batches of point sets, averaged over the batch.

    For each point in ``pc1`` the Euclidean distance to its nearest neighbour in ``pc2``
    is summed, and vice versa; the two sums are added per batch element and then averaged.

    Args:
        pc1: [B, N, D] points.
        pc2: [B, M, D] points.

    Returns:
        Scalar tensor.
    """
    # cdist(pc2, pc1) is just the transpose of cdist(pc1, pc2), so compute the [B, N, M]
    # matrix once and reduce it along both point axes.
    dist = torch.cdist(pc1, pc2, p=2)
    nearest_in_pc2 = dist.min(dim=-1)[0].sum(dim=-1)  # pc1 -> pc2
    nearest_in_pc1 = dist.min(dim=-2)[0].sum(dim=-1)  # pc2 -> pc1
    return (nearest_in_pc2 + nearest_in_pc1).mean()


In [None]:
def _encode_frames(pointnet2, batch):
    """Encode every frame of a [B, T, N, 3] batch independently and stack to [B, T, F]."""
    return torch.stack([pointnet2(batch[:, t, :, :]) for t in range(batch.size(1))], dim=1)


def train_model(pointnet2, seq_model, dataloader, optimizer, epochs, device, save_path, criterion=None):
    """Jointly train the frame encoder and the sequence model, checkpointing the best epoch.

    Each batch of sequences [B, T, N, 3] is encoded frame by frame to [B, T, F]. The
    sequence model maps that to [B, T, F], and ``criterion(output_features, features)``
    is minimized.

    Args:
        pointnet2: Frame encoder; called on [B, N, 3] and expected to return [B, F].
        seq_model: Module mapping [B, T, F] -> [B, T, F].
        dataloader: Sized iterable yielding [B, T, N, 3] tensors.
        optimizer: Optimizer over the parameters to train.
        epochs: Number of passes over ``dataloader``.
        device: Device each batch is moved to.
        save_path: Checkpoint path. It is overwritten whenever an epoch's mean loss improves.
            It stores the ``pointnet2``/``seq_model``/``optimizer`` state dicts plus ``epoch``
            (0-based) and ``loss``.
        criterion: Loss ``f(pred, target) -> scalar tensor``; defaults to ``chamfer_distance``.

    Raises:
        ValueError: if ``dataloader`` has no batches (the mean epoch loss would be undefined).
    """
    if criterion is None:
        criterion = chamfer_distance
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_model: dataloader is empty; check the dataset path and frame filter")

    pointnet2.train()
    seq_model.train()
    best_loss = float('inf')

    for epoch in range(epochs):
        epoch_loss = 0.0
        with tqdm(dataloader, total=num_batches, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch") as pbar:
            for batch in pbar:
                batch = batch.to(device)  # [B, T, N, 3]
                optimizer.zero_grad()

                features = _encode_frames(pointnet2, batch)  # [B, T, F]
                output_features = seq_model(features)  # [B, T, F]

                # NOTE(review): the target is the encoder's own, non-detached output at the
                # *same* time step, and both networks are optimized together. A constant encoder
                # output therefore drives this loss to zero (representation collapse). With the
                # default criterion, the T feature vectors are also compared as an unordered set.
                # Consider a point-space reconstruction loss and/or next-step targets
                # (output_features[:, :-1] vs features[:, 1:].detach()).
                loss = criterion(output_features, features)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                epoch_loss += batch_loss
                pbar.set_postfix(loss=batch_loss)

        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1} complete. Average Loss: {avg_loss:.6f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'pointnet2': pointnet2.state_dict(),
                'seq_model': seq_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'loss': best_loss
            }, save_path)
            print(f"Saved best model with loss: {best_loss:.6f}")


In [None]:
def save_point_cloud(points, file_path):
    """Write an [N, 3] array of xyz coordinates to ``file_path`` as an ASCII PLY file.

    Missing parent directories are created; a bare filename is written to the
    current working directory.
    """
    parent = os.path.dirname(file_path)
    if parent:  # os.makedirs('') raises, so only create a directory when the path names one
        os.makedirs(parent, exist_ok=True)
    header = (
        "ply\n"
        "format ascii 1.0\n"
        f"element vertex {points.shape[0]}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "end_header\n"
    )
    with open(file_path, 'w') as f:
        f.write(header)
        f.writelines(f"{point[0]} {point[1]} {point[2]}\n" for point in points)

# Main execution
# Configuration: everything a reader might tune lives here.
SEED = 42
DATA_DIR = "your path"  # Update with your dataset path
BATCH_SIZE = 4
NUM_WORKERS = 4
EPOCHS = 50
LEARNING_RATE = 1e-3
CHECKPOINT_PATH = "best_model.pth"

# Seed weight init, shuffling and any numpy randomness so re-runs are comparable.
torch.manual_seed(SEED)
np.random.seed(SEED)

data_dir = DATA_DIR
dataset = HumanPointCloudDataset(data_dir)
if len(dataset) == 0:
    raise ValueError(f"No complete sequences found under {data_dir!r}")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pointnet2 = PointNet2FeatureExtractor().to(device)
seq_model = SequenceModel(input_dim=512, hidden_dim=256, num_layers=2).to(device)  # 512 = encoder feature size
optimizer = torch.optim.Adam(list(pointnet2.parameters()) + list(seq_model.parameters()), lr=LEARNING_RATE)

train_model(pointnet2, seq_model, dataloader, optimizer, epochs=EPOCHS, device=device, save_path=CHECKPOINT_PATH)

print("Training completed.")

In [None]:
# Decode features to point cloud
def decode_features_to_point_cloud(features, pointnet2, num_points=1024):
    """Decode a [B, T, F] feature sequence into point clouds of shape [B, T, num_points, 3].

    Each time step is decoded independently with ``pointnet2.decode``. Its channel-first
    [B, 3, num_points] output is transposed to [B, num_points, 3] before stacking along time.
    """
    _, num_steps, _ = features.shape
    frames = [
        pointnet2.decode(features[:, step, :], num_points).permute(0, 2, 1)
        for step in range(num_steps)
    ]
    return torch.stack(frames, dim=1)

# Save decoded point clouds
output_dir = "your path"
pointnet2.eval()
seq_model.eval()

# The training loader shuffles, so iterate a non-shuffled loader instead. Then sample_{k}
# corresponds to dataset.sequences[k] and file names are stable across runs.
eval_loader = DataLoader(
    dataset,
    batch_size=dataloader.batch_size,
    shuffle=False,
    num_workers=dataloader.num_workers,
)

sample_idx = 0
with torch.no_grad():
    for batch in eval_loader:
        batch = batch.to(device)  # [B, T, N, 3]
        _, T, N, _ = batch.size()
        features = torch.stack([pointnet2(batch[:, t, :, :]) for t in range(T)], dim=1)  # [B, T, 512]

        output_features = seq_model(features)  # [B, T, 512]

        reconstructed_sequence = decode_features_to_point_cloud(output_features, pointnet2, num_points=N)  # [B, T, N, 3]

        for sequence in reconstructed_sequence:
            for timestep, points in enumerate(sequence):
                save_point_cloud(points.cpu().numpy(), os.path.join(output_dir, f"sample_{sample_idx}_timestep_{timestep}.ply"))
            sample_idx += 1

print("Point cloud reconstruction completed.")