Note to self: use CONDA Env

In [8]:
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.datasets import ModelNet
from torch_geometric.transforms import SamplePoints
from torch_geometric.transforms import Compose
from torch_geometric.transforms import LinearTransformation
from torch_geometric.transforms import GenerateMeshNormals
from torch_geometric.transforms import NormalizeScale
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torch_scatter import scatter_mean
import torch_geometric.nn.conv as conv
from torch_geometric import utils

import sys
import numpy as np

In [9]:
num_points = 216
batch_size = 32
transforms = Compose([SamplePoints(num_points, include_normals=True), NormalizeScale()])
dataset_train = ModelNet(root="data/ModelNet10", name='10', train=True, transform=transforms)
dataset_test = ModelNet(root="data/ModelNet10", name='10', train=False, transform=transforms)
print(dataset_train[0])

Data(pos=[216, 3], y=[1], normal=[216, 3])


In [10]:
loader_train = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
loader_iter = iter(loader_train)

data = loader_iter.next()
print(data)
print(data.batch)

DataBatch(pos=[6912, 3], y=[32], normal=[6912, 3], batch=[6912], ptr=[33])
tensor([ 0,  0,  0,  ..., 31, 31, 31])


In [11]:
batch = loader_iter.next()
batch_pos = batch.pos
batch_normal = batch.normal
# Concatenating the position and normals
batch_X = np.append(batch_pos.numpy(), batch_normal.numpy(), axis=1)
print(batch_X.shape)


(6912, 6)


In [12]:
batch_X_aux = batch_X.reshape(batch_size, num_points, 6)
print(batch_X_aux.shape)
batch_X_re = batch_X_aux.reshape([32 * 216, 6])
print(batch_X_re.shape)
print(batch_X.shape)
print(batch_X == batch_X_re)

(32, 216, 6)
(6912, 6)
(6912, 6)
[[ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 ...
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]]


In [13]:
class GetGraph(nn.Module):
    def __init__(self):
        """
        Creates the weighted adjacency matrix 'W'
        Taked directly from RGCNN
        """
        super(GetGraph, self).__init__()

    def forward(self, point_cloud):
        point_cloud_transpose = point_cloud.permute(0, 2, 1)
        point_cloud_inner = torch.matmul(point_cloud, point_cloud_transpose)
        point_cloud_inner = -2 * point_cloud_inner
        point_cloud_square = torch.sum(torch.mul(point_cloud, point_cloud), dim=2, keepdim=True)
        point_cloud_square_tranpose = point_cloud_square.permute(0, 2, 1)
        adj_matrix = point_cloud_square + point_cloud_inner + point_cloud_square_tranpose
        adj_matrix = torch.exp(-adj_matrix)
        return adj_matrix

In [16]:
get_graph = GetGraph()
print(batch_X.shape)
W = get_graph(torch.tensor(batch_X_aux))
print(W.shape)

W_reshaped = W.reshape([batch_size * num_points, -1])
print(W_reshaped.shape)
print(batch.batch.shape)

(6912, 6)
torch.Size([32, 216, 216])
torch.Size([6912, 216])
torch.Size([6912])


In [17]:
cheb_conv = conv.ChebConv(128, 512, 5)
'''
for i, graph in enumerate(W):
    print(graph.shape)
    edge_index, edge_weight = utils.dense_to_sparse(graph)
'''
edge_index = torch.zeros([32, 2, 46656])
edge_weight = torch.zeros([32, 1, 46656])
edges = [utils.dense_to_sparse(graph) for graph in W]
for i, edge in enumerate(edges):
    print(torch.tensor(edge[0]).shape)
    edge_index[i] = torch.tensor(edge[0])
    edge_weight[i] = torch.tensor(edge[1])
edge_index = edge_index.reshape(2, 32 * 46656)
edge_weight = edge_weight.reshape(32 * 46656)

print(edge_index.dtype)
print(edge_weight.shape)

torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.Size([2, 46656])
torch.float32
torch.Size([1492992])


  print(torch.tensor(edge[0]).shape)
  edge_index[i] = torch.tensor(edge[0])
  edge_weight[i] = torch.tensor(edge[1])


In [18]:
edge_index = torch.tensor(edge_index, dtype=torch.long)
edge_weight = torch.tensor(edge_weight, dtype=torch.float)
batch_X = torch.tensor(batch_X, dtype=torch.float)
X = torch.randn([32, 216, 128])
out = cheb_conv(X, edge_index=edge_index, edge_weight=edge_weight, batch=batch.batch)
print("out: ", out.shape)
W = get_graph(out)
print("W:   ", W.shape)


out:  torch.Size([32, 216, 512])
W:    torch.Size([32, 216, 216])


  edge_index = torch.tensor(edge_index, dtype=torch.long)
  edge_weight = torch.tensor(edge_weight, dtype=torch.float)
