In [None]:
from pathlib import Path

from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T

# Dataset is stored under ./modelnet10 relative to the current working directory.
current_path = Path.cwd()
dataset_dir = current_path / "modelnet10"

# Per PyG's Dataset contract, `pre_transform` is applied when the raw meshes are
# processed and the result is saved to disk, so editing it later requires clearing
# the processed/ folder. Steps: sample 1024 surface points (keeping normals,
# dropping faces), then NormalizeScale.
# NOTE(review): SamplePoints is random and no seed is set, so the cached point
# clouds differ between fresh processing runs.
pre_transform = T.Compose(
    [
        T.SamplePoints(1024, remove_faces=True, include_normals=True),
        T.NormalizeScale(),
    ]
)

# Both splits share every constructor argument except `train`.
modelnet_kwargs = dict(
    name="10",
    transform=None,
    pre_transform=pre_transform,
    pre_filter=None,
)
train_dataset = ModelNet(dataset_dir, train=True, **modelnet_kwargs)
test_dataset = ModelNet(dataset_dir, train=False, **modelnet_kwargs)

In [None]:
# `torch_geometric.data.DataLoader` is deprecated since PyG 2.0 in favour of
# `torch_geometric.loader.DataLoader`; fall back only for older installs.
try:
    from torch_geometric.loader import DataLoader
except ImportError:  # PyG < 2.0
    from torch_geometric.data import DataLoader

# Take the first mini-batch of 32 training clouds. shuffle=False makes this the
# same 32 samples on every run, which keeps the exploration cells below stable.
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
batch = next(iter(dataloader))
print(batch)

In [None]:
from torch_geometric.nn import knn

K_NEIGHBORS = 16  # neighbours per query point for this exploratory look

# knn(x, y, ...) searches within each cloud (batch_x / batch_y) and returns a
# [2, num_pairs] index tensor. Per the PyG knn docs, row 0 indexes `y` (query
# points) and row 1 indexes `x` (their neighbours) — verify for your version.
assign_index = knn(
    x=batch.pos,
    y=batch.pos,
    k=K_NEIGHBORS,
    batch_x=batch.batch,
    batch_y=batch.batch,
)
print(assign_index.shape)
print(assign_index)

In [None]:
# Gather coordinates for every index pair returned by knn in the previous cell.
# NOTE(review): under PyG's knn convention, row 1 holds neighbour indices and
# row 0 query indices, so `p` are neighbour positions and `q` query positions.
p, q = batch.pos[assign_index[1]], batch.pos[assign_index[0]]
print(p.shape, q.shape)

In [None]:
import torch
from torch_geometric.nn import max_pool_x
import torch.nn as nn

class EdgeConv(nn.Module):
    """DGCNN-style edge convolution over point positions.

    For every point i, finds its k nearest neighbours j within the same cloud.
    It builds the edge features ``[x_i, x_j - x_i]`` and runs them through a
    shared Linear -> BatchNorm1d -> LeakyReLU(0.2) MLP. The results are then
    max-pooled back onto point i.

    Args:
        k: number of nearest neighbours per point (default 20).
        out_channels: output width of the shared MLP (default 64).
    """

    def __init__(self, k=20, out_channels=64):
        super().__init__()
        self.k = k
        # Input is [x_i, x_j - x_i] with 3-D positions -> 6 features per edge.
        self.shared_mlp = nn.Sequential(
            nn.Linear(6, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, batch):
        """Compute one feature vector per point.

        Args:
            batch: a PyG ``Data``/``Batch`` whose ``pos`` has 3 columns. If
                ``batch.batch`` is ``None``, all points are treated as one cloud.

        Returns:
            Tensor of shape ``[num_points, out_channels]``. Presumably row i
            belongs to point i, since max_pool_x orders clusters by sorted id —
            confirm for your PyG version.
        """
        pos = batch.pos
        point_batch = batch.batch
        if point_batch is None:
            # Single (un-batched) cloud: every point belongs to graph 0.
            point_batch = pos.new_zeros(pos.size(0), dtype=torch.long)

        # Per PyG's knn contract: row 0 indexes the query points (y), row 1
        # their neighbours in x. With x == y, row 0 is the center i.
        assign_index = knn(
            x=pos, y=pos, k=self.k, batch_x=point_batch, batch_y=point_batch
        )
        center_idx, neighbor_idx = assign_index[0], assign_index[1]

        x_i = pos[center_idx]
        x_j = pos[neighbor_idx]
        edge_feat = self.shared_mlp(torch.cat([x_i, x_j - x_i], dim=1))

        # Pool over each center's own k neighbours (fixed-size groups). Pooling
        # by neighbour index instead would aggregate reverse-kNN sets.
        out, _ = max_pool_x(
            cluster=center_idx, x=edge_feat, batch=point_batch[center_idx]
        )
        return out

# Smoke test: run the layer on the mini-batch loaded earlier and check its shape.
f = EdgeConv()
y = f(batch)
print(y.shape)

# Expect one 64-dim feature row per input point across the whole batch.
assert tuple(y.shape) == (batch.pos.size(0), 64), f"unexpected output shape {tuple(y.shape)}"