In [3]:
# import torch
# import os
# import torch.nn as nn
import open3d as o3d
# import numpy as np
# from torch.utilities.data import Dataset, DataLoader, random_split
# from torch_snippets.torch_loader import Report



# Dataset & Dataloader

In [3]:

class Read_ply(Dataset):
    def __init__(self, ply_path):
        super(Read_ply, self).__init__()
        self.ply_paths = ply_path
    
    def __len__(self):
        return len(self.ply_paths)
    
    def __getitem__(self, index):
        points = np.asarray(o3d.io.read_point_cloud(self.ply_paths[index]).points)
        return points

    def collate_fn(self, points):
        max_points = max([p.shape[0] for p in points])  # Find the maximum number of points in the batch
        
        # Pad or truncate each point cloud to the same size (max_points)
        padded_points = []
        for point_cloud in points:
            if point_cloud.shape[0] < max_points:
                # Pad the point cloud with zeros to match the maximum size
                padded = np.pad(point_cloud, ((0, max_points - point_cloud.shape[0]), (0, 0)), mode='constant')
                padded_points.append(padded)

        # Stack the point clouds to form a batch
        points = np.stack(padded_points, axis=0)  # (B, N, 3)

        # Create the target as the next frame
        target = np.roll(points, shift=-1, axis=0)  # Shift the points by one frame along the batch dimension
        target[-1] = points[-1]  # Set the last target to the last frame (no next frame for it)

        # Convert to tensors
        points = torch.tensor(points).float()
        target = torch.tensor(target).float()

        return points, target              

def RandomSplit(datasets, train_set_percentage):
    lengths = [int(len(datasets)*train_set_percentage), len(datasets)-int(len(datasets)*train_set_percentage)]
    return random_split(datasets, lengths)


def GetDataLoader(ply_paths, batch_size, train_set_percentage=0.9, shuffle=True, drop_last=True):
    # Defining the dataset
    ds = Read_ply(ply_paths)
    # Randomly splitting the dataset
    train_set, val_set = RandomSplit(ds, train_set_percentage)

    # Defining the dataloader
    val_dl = DataLoader(val_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=ds.collate_fn)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=ds.collate_fn)
    
    return train_dl, val_dl

In [4]:
ply_paths = [f'data/ped_4ff8af4d-6840-47c2-bc9b-eb383009ad65/frame_{i}.ply' for i in range(0,20,2)]

train_dl, val_dl = GetDataLoader(ply_paths, 2)

for batch in train_dl:
    print(batch[0].shape, batch[1].shape)
    break


torch.Size([1, 2796887, 3]) torch.Size([1, 2796887, 3])


# Attention Vector Architecture

In [23]:

class MLP(nn.Module):
    def __init__(self, hidden_size, embed_dim):
        super(MLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(3, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, embed_dim)  # Transform to higher-dimensional space
        )

    def forward(self, x):
        """
        Input shape x: (batch_size, num_points, 3)
        Output shape: (batch_size, num_points, embed_dim)
        """
        out = self.mlp(x)
        return out


class AttentionMLP(nn.Module):
    def __init__(self, embed_dim, hidden_size):
        super(AttentionMLP, self).__init__()
        self.att_mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1), # Scalar score  for each point
            nn.Softmax(dim=1) # Normalize scores across points
        )

    def forward(self, x):
        """
        Input shape x: (batch_size, num_points, embed_dim)
        Output shape: (batch_size, num_points, 1)
        """
        out = self.att_mlp(x)
        return out
    
class AttentionVector(nn.Module):
    def __init__(self, embed_dim, hidden_size, num_heads=2, dropout=0.1):
        super(AttentionVector, self).__init__()
        self.feature_mlp = MLP(hidden_size, embed_dim)
        self.att_mlp = AttentionMLP(embed_dim, hidden_size)

    def forward(self, x):
        """
        Input shape x: (batch_size, num_points, 3)
        Output shape: (batch_size, num_points, embed_dim)
        """
        features = self.feature_mlp(x) # -> (batch_size, num_points, embed_dim)
        attention_scores = self.att_mlp(features) # -> (batch_size, num_points, 1)
        wei = attention_scores * features # -> (batch_size, num_points, embed_dim)

        return wei 



## Initializing the model


In [24]:
model = AttentionVector(embed_dim=16, hidden_size=32)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# criterion = nn.MSELoss()

print(sum(p.numel() for p in model.parameters()), "parameters")


1233 parameters


# Not Necessary

In [17]:
def train_batch(model, optimizer, criterion, tr_dl):
    total_loss = 0
    model.train()
    for input, target in tr_dl:    
        output = model(input)
        loss = criterion(output, target)
        optimizer.zero_grad(set_to_none= True)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        average_loss = total_loss / len(tr_dl)

    return average_loss

@torch.no_grad()
def val_batch(model, criterion, val_dl):
    total_loss = 0
    model.eval()
    for input, target in val_dl:
        output = model(input)
        loss = criterion(output, target)
        total_loss += loss.item()
        average_loss = total_loss / len(val_dl)

    return average_loss

In [18]:
n_epochs = 5
log = Report(n_epochs)
for epoch in range(n_epochs):
  N=len(train_dl)

  for ix, _ in enumerate(train_dl):
    avg_loss = train_batch(model, optimizer, criterion, train_dl)
    log.record(epoch+(ix+1)/N, trn_loss= avg_loss, end='\r')

  val_loss=0
  N=len(val_dl)
  for ix, _ in enumerate(val_dl):

    loss= val_batch(model, criterion, val_dl)
    val_loss+= loss
    log.record(epoch+(ix+1)/N, val_loss=loss, end='\r')

  log.report_avgs(epoch+1)
log.plot_epochs(['trn_loss', 'val_loss'])

EPOCH: 1.000  trn_loss: 45.607  (87.42s - 349.68s remaining)
EPOCH: 2.000  trn_loss: 45.074  (178.77s - 268.16s remaining)
EPOCH: 3.000  trn_loss: 45.562  (366.82s - 244.54s remaining)
EPOCH: 3.250  trn_loss: 46.449  (399.30s - 215.01s remaining)

KeyboardInterrupt: 

In [None]:
pcd_points = o3d.io.read_point_cloud(r"data\ped_4ff8af4d-6840-47c2-bc9b-eb383009ad65\frame_0.ply").points

pcd_points = torch.tensor(np.asarray(pcd_points), dtype=torch.float32)

out = model(pcd_points)
print(out.shape)


torch.Size([2796887, 3])


In [22]:
np_points = (out).detach().numpy()
np_points=np_points.astype(np.float64)
print(np_points.shape, np_points.dtype)

points = o3d.utility.Vector3dVector(np_points)
pcd = o3d.geometry.PointCloud()
pcd.points = points
o3d.visualization.draw_geometries([pcd])

(2796887, 3) float64
