In [1]:
import torch_geometric
from torch_geometric.datasets import ModelNet
from torch_geometric.transforms import SamplePoints
from torch_geometric.transforms import NormalizeScale
from torch_geometric.loader import DataLoader
import torch_geometric.utils as utils
import torch_geometric.nn.conv as conv
from torch_geometric.transforms import Compose


import os

import torch

# Seed early: point sampling (SamplePoints), loader shuffling and weight init all draw
# from the torch RNG, so a fixed seed makes Restart & Run All reproducible.
SEED = 0
torch.manual_seed(SEED)

num_points = 1024
batch_size = 16
modelnet_num = 10
assert modelnet_num in (10, 40), "ModelNet only exists as the 10- and 40-class variants"

transforms = Compose([SamplePoints(num_points, include_normals=True), NormalizeScale()])

# Dataset location; set MODELNET_ROOT to point elsewhere instead of editing this path.
root = os.environ.get(
    "MODELNET_ROOT",
    "/home/victor/workspace/thesis_ws/datasets/Modelnet" + str(modelnet_num),
)
print(root)
dataset_train = ModelNet(root=root, name=str(modelnet_num), train=True, transform=transforms)
dataset_test = ModelNet(root=root, name=str(modelnet_num), train=False, transform=transforms)

# Verification...
print(f"Train dataset shape: {dataset_train}")
print(f"Test dataset shape:  {dataset_test}")

print(dataset_train[0])

dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, pin_memory=True)
dataloader_test  = DataLoader(dataset_test, batch_size=batch_size)

/home/victor/workspace/thesis_ws/datasets/Modelnet10
Train dataset shape: ModelNet10(3991)
Test dataset shape:  ModelNet10(908)
Data(pos=[1024, 3], y=[1], normal=[1024, 3])


In [2]:
import torch.nn as nn
class GetGraph(nn.Module):
    """Module wrapper that builds the RGCNN weighted adjacency graph.

    Taken directly from RGCNN. For each point cloud in the batch it builds a fully
    connected graph whose edge weights are ``exp(-||f_i - f_j||^2)`` over the
    per-point feature vectors, and returns it in sparse ``(edge_index, edge_weight)``
    form.
    """

    def __init__(self):
        super(GetGraph, self).__init__()

    def forward(self, point_cloud, batch):
        """Return ``(edge_index, edge_weight)`` for the batched point clouds.

        Args:
            point_cloud: Stacked node features of shape ``(num_nodes, num_features)``.
                Assumes every cloud has the same number of points, stored as
                contiguous rows.
            batch: Cloud index of every node; only its number of distinct values is used.
        """
        # Derive the number of clouds from `batch`, not the global `batch_size`. The last
        # DataLoader batch can be smaller (ModelNet10 train has 3991 = 249*16 + 7 clouds),
        # and reshaping it into `batch_size` groups silently mixes points across clouds.
        num_clouds = batch.unique().shape[0]
        point_cloud = point_cloud.reshape(num_clouds, -1, point_cloud.shape[-1])
        point_cloud_transpose = point_cloud.permute(0, 2, 1)
        # Pairwise squared distance: ||a||^2 - 2 a.b + ||b||^2.
        point_cloud_inner = -2 * torch.matmul(point_cloud, point_cloud_transpose)
        point_cloud_square = torch.sum(torch.mul(point_cloud, point_cloud), dim=2, keepdim=True)
        point_cloud_square_transpose = point_cloud_square.permute(0, 2, 1)
        adj_matrix = point_cloud_square + point_cloud_inner + point_cloud_square_transpose
        adj_matrix = torch.exp(-adj_matrix)
        edge_index, edge_weight = utils.dense_to_sparse(adj_matrix)
        return edge_index, edge_weight

def get_graph(point_cloud, batch):
    """Build the RGCNN fully connected Gaussian-kernel graph for a batch of point clouds.

    For every ordered pair of points (i, j) in the same cloud (including i == j), the
    edge weight is ``exp(-||f_i - f_j||^2)``, where ``f`` is the full per-point feature
    vector. The dense per-cloud adjacency is then converted to sparse form.

    Args:
        point_cloud: Node features of the whole batch, shape ``(num_nodes, num_features)``.
            Assumes each cloud's points are contiguous rows, as produced by the
            ``DataLoader`` in this notebook.
        batch: Cloud index of every node, shape ``(num_nodes,)``.

    Returns:
        ``(edge_index, edge_weight)`` from ``torch_geometric.utils.dense_to_sparse``.

    Raises:
        ValueError: If the clouds in ``batch`` do not all have the same number of points.
    """
    # Clouds are packed into one dense (num_clouds, points_per_cloud, num_features)
    # tensor. Unequal cloud sizes would silently mix points across clouds, so reject them.
    _, points_per_cloud = batch.unique(return_counts=True)
    if points_per_cloud.numel() and not bool((points_per_cloud == points_per_cloud[0]).all()):
        raise ValueError(
            f"get_graph expects equally sized point clouds, got sizes {points_per_cloud.tolist()}"
        )
    num_clouds = points_per_cloud.numel()
    # Use the real feature width instead of a hard-coded 6.
    num_features = point_cloud.shape[-1]
    point_cloud = point_cloud.reshape(num_clouds, -1, num_features)

    # Pairwise squared distance: ||a||^2 - 2 a.b + ||b||^2 (same summation order as before).
    point_cloud_inner = -2 * torch.matmul(point_cloud, point_cloud.permute(0, 2, 1))
    point_cloud_square = torch.sum(torch.mul(point_cloud, point_cloud), dim=2, keepdim=True)
    adj_matrix = point_cloud_square + point_cloud_inner + point_cloud_square.permute(0, 2, 1)
    adj_matrix = torch.exp(-adj_matrix)
    edge_index, edge_weight = utils.dense_to_sparse(adj_matrix)

    return edge_index, edge_weight

def get_graph_v2(point_cloud, batch):
    """Compact variant of the RGCNN Gaussian-kernel graph builder.

    Packs the stacked 6-feature points into one dense tensor per cloud. It then forms
    pairwise squared distances ``||a||^2 - 2 a.b + ||b||^2``, applies ``exp(-d)`` and
    returns the sparse ``(edge_index, edge_weight)`` pair from ``dense_to_sparse``.
    """
    n_clouds = batch.unique().shape[0]
    pts = point_cloud.reshape(n_clouds, -1, 6)

    sq_norm = torch.sum(torch.mul(pts, pts), dim=2, keepdim=True)
    gram = torch.matmul(pts, pts.permute(0, 2, 1))
    dist2 = sq_norm - 2 * gram + sq_norm.permute(0, 2, 1)

    return utils.dense_to_sparse(torch.exp(-dist2))




In [3]:
import torch

# Debug switch: set to True to compare the peak CUDA memory of the two graph builders
# on one training batch. Needs a GPU, so it is off by default.
RUN_GRAPH_MEMORY_CHECK = False

if RUN_GRAPH_MEMORY_CHECK:
    # DataLoader iterators have no `.next()` method; use the built-in next().
    data = next(iter(dataloader_train))
    point_cloud = torch.cat([data.pos, data.normal], dim=1).to("cuda")
    batch = data.batch.to("cuda")
    for build_graph in (get_graph, get_graph_v2):
        torch.cuda.reset_peak_memory_stats(device="cuda")
        edge_index, edge_weight = build_graph(point_cloud, batch)
        peak_mib = torch.cuda.max_memory_allocated(device="cuda") / 2**20
        print(f"{build_graph.__name__}: peak {peak_mib:.1f} MiB, {edge_index.shape[1]} edges")
        del edge_index, edge_weight

    del point_cloud, batch


In [4]:
from torch import nn
from torch_geometric.nn.conv import ChebConv
from torch.nn import Linear
from torch_cluster import knn_graph
from torch_geometric.nn import global_max_pool

class RGCNN_model(nn.Module):
    """Minimal RGCNN-style point-cloud classifier.

    Pipeline: dense Gaussian-kernel graph from ``get_graph`` -> one Chebyshev graph
    convolution -> ReLU -> per-cloud global max pooling -> linear classifier.

    Args:
        in_channels: Features per point (6 = position + normal in this notebook).
        hidden_channels: Output width of the Chebyshev convolution.
        K: Chebyshev filter size passed to ``ChebConv``.
        num_classes: Number of output logits; defaults to ``modelnet_num`` at
            construction time.
    """

    def __init__(self, in_channels=6, hidden_channels=128, K=6, num_classes=None):
        super(RGCNN_model, self).__init__()
        if num_classes is None:
            num_classes = modelnet_num
        # Construction order (conv1 before fc1) is kept so seeded weight init is unchanged.
        self.conv1  = ChebConv(in_channels, hidden_channels, K)
        self.fc1    = Linear(hidden_channels, num_classes)
        self.relu   = nn.ReLU()

    def forward(self, x, batch):
        """Classify every point cloud in the batch.

        Args:
            x: Stacked per-point features of all clouds, shape ``(num_points, in_channels)``.
            batch: Cloud index of every point.

        Returns:
            Logits of shape ``(num_clouds, num_classes)``.
        """
        edge_index, edge_weight = get_graph(x, batch=batch)
        out = self.conv1(x=x, edge_index=edge_index, edge_weight=edge_weight, batch=batch)
        out = self.relu(out)
        out = global_max_pool(out, batch)
        out = self.fc1(out)
        return out

model = RGCNN_model()
print(model)

RGCNN_model(
  (conv1): ChebConv(6, 128, K=6, normalization=sym)
  (fc1): Linear(in_features=128, out_features=10, bias=True)
  (relu): ReLU()
)


In [5]:
import torch

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 0.01

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
def train(model, optimizer, loader):
    """Run one optimisation epoch of ``model`` over ``loader``.

    For each batch, position and normal features are concatenated along dim 1 and moved
    to the global ``device``. The global ``criterion`` loss on the logits is
    back-propagated, and ``optimizer`` takes one step.

    Returns:
        The per-graph average loss over ``loader.dataset``: each batch's loss is
        weighted by its ``num_graphs`` before dividing by the dataset size.
    """
    model.train()
    running_loss = 0
    for batch_data in loader:
        optimizer.zero_grad()
        features = torch.cat([batch_data.pos, batch_data.normal], dim=1).to(device)
        logits = model(features, batch_data.batch.to(device))
        batch_loss = criterion(logits, batch_data.y.to(device))
        batch_loss.backward()
        optimizer.step()
        running_loss = running_loss + batch_loss.item() * batch_data.num_graphs

    dataset_size = len(loader.dataset)
    return running_loss / dataset_size

@torch.no_grad()
def test(model, loader):
    model.eval()

    total_correct = 0
    for data in loader:
        x = torch.cat([data.pos, data.normal], dim=1)
        logits = model(x.to(device), data.batch.to(device))
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y.to(device)).sum())

    return total_correct / len(loader.dataset)

NUM_EPOCHS = 50

# Train for NUM_EPOCHS epochs, reporting mean training loss and test accuracy after each.
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train(model, optimizer, dataloader_train)
    test_acc = test(model, dataloader_test)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')


Epoch: 01, Loss: 0.7230, Test Accuracy: 0.7808
Epoch: 02, Loss: 0.4011, Test Accuracy: 0.8018
Epoch: 03, Loss: 0.3330, Test Accuracy: 0.8304
Epoch: 04, Loss: 0.3116, Test Accuracy: 0.7952
Epoch: 05, Loss: 0.2842, Test Accuracy: 0.8469
