Final tests for the journal review. What is tested:
    
    - Compute/plot additional metrics (FLOPs, model size, etc.)
    - Point-cloud (PCD) visualization: ground truth (GT) vs. prediction
    - Test with all noise types combined
    

In [1]:
# Imports:
import time
import imports
import torch
import os
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DenseDataLoader
from torch_geometric.transforms import FixedPoints, Compose, NormalizeScale, NormalizeRotation, RandomRotate
from FilteredShapenetDataset import FilteredShapeNet, ShapeNetCustom
from RGCNNSegmentation import seg_model
import numpy as np

import pickle

from pathlib import Path
from collections import defaultdict

from utils import BoundingBoxRotate, label_to_cat
from utils import seg_classes

from TestModel import ModelTester

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def get_n_params(model, trainable_only=False):
    """Return the total number of scalar parameters held by a model.

    Args:
        model: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        trainable_only: if True, count only parameters with ``requires_grad=True``.
            Defaults to False, i.e. every parameter is counted.

    Returns:
        int: the sum of element counts over all (selected) parameter tensors;
        0 for a model without parameters.
    """
    # Tensor.numel() is the product of the tensor's sizes (1 for a 0-dim tensor).
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

In [3]:
# Load the ShapeNet test split (restricted to the categories in `seg_classes`),
# with each shape sub-sampled to 2048 points by the FixedPoints transform.
# NOTE(review): PyG applies `transform` on every `dataset[i]` access and FixedPoints
# samples at random, so repeated indexing of the same item may return different
# point subsets — index once and reuse the result.
dataset_name = "Original_2048"
dataset = ShapeNet(
    root=f"{imports.dataset_path}/Journal/ShapeNet/",
    split="test",
    transform=FixedPoints(2048),
    categories=seg_classes,
)
print(dataset[0])

Data(x=[2048, 3], y=[2048], pos=[2048, 3], category=[1])


In [4]:
# Checkpoint file names, resolved relative to `imports.curr_path`.
model_names = ["2048_seg_clean.pt", "2048_seg_bb.pt", "2048_seg_rrbb.pt", "2048_seg_eig.pt", "2048_seg_gauss_rr_bb.pt", "2048_seg_gauss_rr_eig.pt"]

num_points = 2048
input_dim  = 22       # in_features of the first graph conv; presumably point features + one-hot category — TODO confirm

F = [128, 512, 1024]  # Output channels of each graph-convolution layer.
K = [6, 5, 3]         # Chebyshev polynomial order of each graph-convolution layer.
M = [512, 128, 50]    # Widths of the per-point fully connected layers; the last is the output size.

for model in model_names:
    net = seg_model(num_points, F, K, M, input_dim, dropout=0.2, reg_prior=False)
    # The network is built on the CPU, so map the checkpoint there as well; without
    # map_location a checkpoint saved from CUDA tensors fails on a GPU-less machine.
    net.load_state_dict(torch.load(f"{imports.curr_path}/{model}", map_location="cpu"))
    net.eval()
    print(f"\nModel: {model}")
    print(get_n_params(net))
    print(net)
    


Model: 2048_seg_clean.pt
7413666
seg_model(
  (dropout): Dropout(p=0.2, inplace=False)
  (bias_relus): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 1x2048x128]
      (1): Parameter containing: [torch.FloatTensor of size 1x2048x512]
      (2): Parameter containing: [torch.FloatTensor of size 1x2048x1024]
      (3): Parameter containing: [torch.FloatTensor of size 1x2048x512]
      (4): Parameter containing: [torch.FloatTensor of size 1x2048x128]
      (5): Parameter containing: [torch.FloatTensor of size 1x2048x50]
  )
  (conv): ModuleList(
    (0): DenseChebConvV2(in_features=22, out_features=128, K=6, bias=True)
    (1): DenseChebConvV2(in_features=128, out_features=512, K=5, bias=True)
    (2): DenseChebConvV2(in_features=512, out_features=1024, K=3, bias=True)
  )
  (batch_norm_list_conv): ModuleList(
    (0): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, t

In [8]:
# Test FLOPS counter:
from fvcore.nn import FlopCountAnalysis

# Set to True to run fvcore's FLOP counter on the loaded model.
COUNT_FLOPS = False


class NetWrapper(torch.nn.Module):
    """Expose `net(x, cat)` as a single-argument forward taking `[x, cat]`.

    Used below so FlopCountAnalysis can be handed one input object.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, data):
        return self.net(data[0], data[1])


with torch.no_grad():
    model = "2048_seg_clean.pt"
    # Index the dataset ONCE: the FixedPoints transform resamples on every access,
    # so separate `dataset[0]` calls would pair positions and features taken from
    # different random point subsets.
    sample = dataset[0]
    pos = sample.pos.unsqueeze(0)
    norm = sample.x.unsqueeze(0)
    # (1, 2048, 6): positions concatenated with the per-point features `x`.
    x = torch.cat([pos.float(), norm.float()], dim=2)
    cat = sample.category
    net = seg_model(num_points, F, K, M, input_dim, dropout=0.2, reg_prior=False)
    # Network lives on the CPU; map CUDA-saved checkpoints there too.
    net.load_state_dict(torch.load(f"{imports.curr_path}/{model}", map_location="cpu"))
    net = net.eval()
    # Do not rebind `x`: it must remain the network INPUT for the FLOP count below.
    out, _, _ = net(x, cat)
    print(out.shape)

    if COUNT_FLOPS:
        net_wrapped = NetWrapper(net)
        data = [x.detach(), cat.detach()]
        flops = FlopCountAnalysis(net_wrapped, data)
        print(f"Total flops:       {flops.total()}")
        print(f"Flops by operator: {flops.by_operator()}")
        print(f"Flops by module:   {flops.by_module()}")


torch.Size([1, 2048, 50])


Point 2

- GT, no rotation
- GT, rotated
- RGCNN, no rotation
- RGCNN, rotated
- Ours, no rotation
- Ours, rotated

Also test DGCNN and PointNet:

- https://github.com/antao97/dgcnn.pytorch/blob/master/main_cls.py


Time complexity:

- Eig: Alex's Git implementation (or PyTorch Geometric `NormalizeRotation`? — TODO decide)
- GRAMM + model: the model's input dimension increases
- BB
- multiview BB - model forward
- multiview Eig - model forward

Break down and time each module of the algorithm.