In [4]:
import torch
import numpy as np

In [13]:
def knn(data, k=5)->torch.Tensor:
    """Construct edge feature for each point
    Args:
      point_cloud: (batch_size, num_points, num_dims)
      k: int number of neighbours

    Returns:
      idx: shape:(batch_size, num_points, nums_neighours,)
    """
    dists_matrix = torch.cdist(data, data)
    print(dists_matrix.shape)
    _, idx = dists_matrix.topk(k+1, dim=-1, largest=False)  # +1 the point itself is included
    return idx[...,1:] # not include the point itself



def get_edge_feature(point_cloud, idx=None, k=20,device="cpu"):
    """Construct edge feature for each point
    Args:
      point_cloud: (batch_size, num_points, num_dims)
      idx: (batch_size, num_points, neighbours)
      k: int
      device: cpu/cuda

    Returns:
      features: (batch_size, num_points, k, num_dims)
    """
    point_cloud = point_cloud.to(device)
    batch_size = point_cloud.shape[0]
    num_points = point_cloud.shape[1]

    if(idx==None):
        idx = knn(point_cloud,k=k) # (batch_size, num_points, nums_neighours,)

    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points # create the base index for mapping
    idx = idx.to(device=device)
    idx = idx + idx_base #[0...0...0]->[0...100...200]
    idx=idx.view(-1) # flatten it -> tensor([  0,  56,  25,  ..., 225, 222, 271], device='cuda:0') e.g: [K01,K02,K03,K11,K12,K13...] shape = (B*N*K) 
   
    num_dims = point_cloud.shape[2]

    # feature : turn neighbour index in idx to coordinate
    feature = point_cloud.view(batch_size*num_points, -1)[idx, :] # feature : B*N*F -> BN * F -> (B*N*K) * F
    # feature : reshape into (Batch_size * Num_points *Nums_neigbours * Features)
    feature = feature.view(batch_size, num_points, k, num_dims)
    # pointcloud : create replicate of the self point up to k for match feature - size B*N*K(repeated)*F 
    point_cloud = point_cloud.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 

    # feature size B*N*K*F -> B*N*K*2F (feature-x || x)
    feature = torch.cat((feature-point_cloud, point_cloud), dim=3)

    # todo (B * 2F * N * K) for conv each coordinate(F)
    feature=feature.permute(0,3,1,2)

    return feature

# Example usage:
data = torch.rand((3,100, 3))  # 100 points in 20D (batch_size, num_points, num_dims)
#neighbors = knn(data, k=4)
edges= get_edge_feature(data)
print(edges.shape)
print(type(edges))

torch.Size([3, 100, 100])
torch.Size([3, 6, 100, 20])
<class 'torch.Tensor'>


### **Edgeconv**

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EdgeConv(nn.Module):
    def __init__(self, in_channels, out_channels, num_neighbours=20,device="cpu"):
        """Setup EdgeConv
        Args:
        in_channels: int
        out_channels: int
        num_neighbours: int
        """
        super(EdgeConv, self).__init__()
        self.device=device
        self.k= num_neighbours
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels*2, out_channels=out_channels, kernel_size=1, bias=False,device=self.device),
            nn.BatchNorm2d(out_channels,device=self.device),
            nn.LeakyReLU(negative_slope=0.2)
        )

    def forward(self,x):
        """Setup EdgeConv
        Args:
        x: shape - (batch_size, num_points, num_dims)

        Returns:
        features: (batch_size, num_dims, num_points, num_neigbours)
        """
        x = get_edge_feature(x, k=self.k,device=self.device)
        x = self.conv(x)
        # pick the largest k (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
        x = x.max(dim=-1, keepdim=False)[0]
        return x
    
# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.rand((3,100, 3))  # 100 points in 20D (batch_size, num_points, num_dims)
conv = EdgeConv(3, 64,device=device)
out = conv(data)
print("out.shape=", out.shape)

torch.Size([3, 100, 100])
out.shape= torch.Size([3, 64, 100])


### **DGCNN (Classification)**

In [None]:
class DGCNN(nn.module):
    def __init__(self, num_neighbours=20,device="cpu"):
        super(DGCNN,self).__init__()

    def forward(self,x):
        raise NotImplemented

In [12]:
#disable
import torch
import torch.nn as nn
import torch.nn.functional as F

class EdgeConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2*in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        start, end = edge_index

        # Create pseudo-edges by concatenating node features of connected nodes
        edge_features = torch.cat([x[start], x[end]], dim=1)  # shape [E, 2*in_channels]

        return self.mlp(edge_features)  # shape [E, out_channels]

# Example usage:
x = torch.rand((100, 3))  # 100 nodes with 64-dimensional features
edge_index = torch.randint(100, (2, 500))  # 500 edges
print(edge_index.shape)

conv = EdgeConv(3, 128)
out = conv(x, edge_index)

print(out.shape)  # Should be [500, 128]

torch.Size([2, 500])
torch.Size([500, 128])
