In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

sys.path.append('..')

In [3]:
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

## DGCNN branch

In [4]:
from dgcnn.model import get_graph_feature


class DGCNNFeatureExtractor(nn.Module):
    '''
    DGCNN (EdgeConv) backbone mapping a batch of point clouds to one global feature vector each.

    Four EdgeConv stages (get_graph_feature -> 1x1 Conv2d -> BN -> LeakyReLU -> max over the
    last axis) are concatenated, projected to emb_dims channels, then globally max- and
    avg-pooled.
    '''
    def __init__(self, args: dict):
        '''
        args must contain:
            'k': int
                num neighbours passed to get_graph_feature at every EdgeConv stage
            'emb_dims': int
                num output channels of the final 1x1 Conv1d (before global pooling)
        args may contain:
            'input_dims': int, default 3
                num per-point input channels (3 for xyz). get_graph_feature doubles the
                channel count (see conv2: 64 in -> 64*2), so conv1 reads 2 * input_dims.
        '''
        super(DGCNNFeatureExtractor, self).__init__()
        self.args = args
        self.k = args['k']
        input_dims = args.get('input_dims', 3)
        emb_dims = args['emb_dims']

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(emb_dims)

        self.conv1 = nn.Sequential(nn.Conv2d(2 * input_dims, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        # conv5 consumes the concatenation of all four EdgeConv outputs: 64 + 64 + 128 + 256 = 512
        self.conv5 = nn.Sequential(nn.Conv1d(64 + 64 + 128 + 256, emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))

    def forward(self, x):
        '''
        x: channels-first point cloud with the batch on dim 0 and input_dims channels on dim 1,
           e.g. (8, 3, 1024) — presumably (batch, input_dims, num_points); TODO confirm
        returns (batch, 2 * emb_dims): global max-pooled and avg-pooled features concatenated
        '''
        batch_size = x.size(0)

        stage_outputs = []
        for edge_conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            # build neighbourhood features, apply the 1x1 conv, reduce over the last axis
            # (presumably the k neighbours produced by get_graph_feature)
            x = edge_conv(get_graph_feature(x, k=self.k)).max(dim=-1, keepdim=False)[0]
            stage_outputs.append(x)

        x = self.conv5(torch.cat(stage_outputs, dim=1))

        x_max = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        x_avg = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        return torch.cat((x_max, x_avg), 1)

In [5]:
dgcnn_args = {
    'k': 5,
    'emb_dims': 128,
}

# Seed first so both the weight init and the random input are reproducible on re-run.
torch.manual_seed(0)
# Random smoke-test batch: 8 samples, 3 channels (xyz), 1024 points.
dgcnn_in = torch.randn((8, 3, 1024))

dgcnn_feat_ex = DGCNNFeatureExtractor(dgcnn_args)
dgcnn_feat_ex.eval()
with torch.no_grad():  # inference only — no autograd graph needed
    dgcnn_out = dgcnn_feat_ex(dgcnn_in)

# Global max- and avg-pooled features are concatenated, so the width is 2 * emb_dims.
assert dgcnn_out.shape == (8, 2 * dgcnn_args['emb_dims'])
dgcnn_feat_ex, dgcnn_out.shape

(DGCNNFeatureExtractor(
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (bn5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (conv1): Sequential(
     (0): Conv2d(6, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): LeakyReLU(negative_slope=0.2)
   )
   (conv2): Sequential(
     (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): LeakyReLU(negative_slope=0.2)
   )
   (conv3): Sequential(
     (0): Conv2d(128, 128, kernel_siz

## MeshCNN branch

In [6]:
from meshcnn.models.networks import MResConv, get_norm_args, get_norm_layer
from meshcnn.models.layers.mesh_pool import MeshPool
from meshcnn.models.layers.mesh import Mesh


class MeshCNNFeatureExtractor(nn.Module):
    '''
    MeshCNN backbone mapping per-edge features of a batch of meshes to one global feature
    vector per mesh: [MResConv -> norm -> ReLU -> MeshPool] x len(conv_res), followed by
    global pooling over the edges left after the last MeshPool.
    '''
    def __init__(self, args: dict):
        '''
        args must contain:
            'nf0': int
                num input channels (5 for the usual MeshCNN initial edge features)
                Corresponds to "opt.input_nc" in original code, with no default (inferred from dataset)

            'conv_res': sequence of ints
                num out channels (i.e. filters) for each meshconv layer
                Corresponds to "opt.ncf" in original code, with default [16, 32, 32]

            'input_res': int
                num input edges (we take only this many edges from each input mesh)
                Corresponds to "opt.ninput_edges" in original code, with default 750

            'pool_res': sequence of ints, same length as conv_res
                num edges to keep after each meshpool layer
                Corresponds to "opt.pool_res" in original code, with default [1140, 780, 580]

            'norm': str, one of ['batch', 'instance', 'group', 'none']
                type of norm layer to use
                Corresponds to "opt.norm" in original code, with default 'batch'

            'num_groups': int
                num of groups for groupnorm
                Corresponds to "opt.num_groups" in original code, with default 16

            'nresblocks': int
                num res blocks in each mresconv
                Corresponds to "opt.resblocks" in original code, with default 0

        args may contain:
            'global_pool': str, one of ['avg', 'max'], default 'avg'
                global pooling applied over the pool_res[-1] edges left after the last meshpool

        Raises:
            ValueError: if conv_res and pool_res differ in length (each meshconv layer is
                followed by exactly one meshpool layer, and the global pool kernel is
                pool_res[-1]), or if global_pool is not 'avg' / 'max'.
        '''
        super(MeshCNNFeatureExtractor, self).__init__()
        conv_res = list(args['conv_res'])
        pool_res = list(args['pool_res'])
        if len(conv_res) != len(pool_res):
            # A longer pool_res would make the global pool kernel (pool_res[-1]) smaller than the
            # real edge count, and the final view() would then silently return too many rows.
            raise ValueError(
                'conv_res and pool_res must have the same length, got {} and {}'.format(
                    len(conv_res), len(pool_res)))

        # k: channel widths per stage, res: edge counts per stage (both include the input stage)
        self.k = [args['nf0']] + conv_res
        self.res = [args['input_res']] + pool_res

        norm_layer = get_norm_layer(norm_type=args['norm'], num_groups=args['num_groups'])
        norm_args = get_norm_args(norm_layer, self.k[1:])

        for i, ki in enumerate(self.k[:-1]):
            setattr(self, 'conv{}'.format(i), MResConv(ki, self.k[i + 1], args['nresblocks']))
            setattr(self, 'norm{}'.format(i), norm_layer(**norm_args[i]))
            setattr(self, 'pool{}'.format(i), MeshPool(self.res[i + 1]))

        global_pool = args.get('global_pool', 'avg')
        if global_pool == 'avg':
            self.gp = nn.AvgPool1d(self.res[-1])
        elif global_pool == 'max':
            self.gp = nn.MaxPool1d(self.res[-1])
        else:
            raise ValueError(
                "global_pool must be 'avg' or 'max', got {!r}".format(global_pool))

    def forward(self, x, mesh):
        '''
        x: edge features; the notebook passes (batch, nf0, input_res, 1) — TODO confirm other layouts
        mesh: list of Mesh objects, presumably one per batch element; handed to every
              MResConv and MeshPool layer
        returns a tensor of shape (-1, conv_res[-1]) — one row per mesh when the last
        meshpool leaves exactly pool_res[-1] edges
        '''
        for i in range(len(self.k) - 1):
            x = getattr(self, 'conv{}'.format(i))(x, mesh)
            x = F.relu(getattr(self, 'norm{}'.format(i))(x))
            x = getattr(self, 'pool{}'.format(i))(x, mesh)

        x = self.gp(x)
        x = x.view(-1, self.k[-1])

        return x


In [7]:
# meshcnn data loader item format - dict, with keys:
#     'mesh': Mesh class instance
#     'label': output class label
#     'edge_features': features extracted using extract_features() of the Mesh object above,
#                      but PADDED TO ninput_edges AND NORMALIZED BY MEAN & STD OF DATA
# NOTE: the hand-built item below uses the raw extract_features() output (no padding or
# normalization), so it only mimics the loader format for shape checking.

# Set the SHREC16_DIR environment variable to point at your copy of the dataset
# instead of editing this machine-specific default.
SHREC16_DIR = os.environ.get(
    'SHREC16_DIR',
    r'C:\Academic\GT - MSCS\Sem II - Spring 2022\CS 7643 - DL\Project\src\dgcnn\pytorch\data\shrec_16')
mesh_path = os.path.join(SHREC16_DIR, 'armadillo', 'test', 'T55.obj')

mesh = Mesh(file=mesh_path, opt=None, export_folder=None)
test_data = {
    'mesh': mesh,
    'label': 0,
    'edge_features': mesh.extract_features(),
}
test_data, test_data['edge_features'].shape

({'mesh': <meshcnn.models.layers.mesh.Mesh at 0x225f257ce48>,
  'label': 0,
  'edge_features': array([[1.57387046, 1.57369169, 1.57342915, ..., 1.57413897, 1.57166371,
          1.57090315],
         [1.35346367, 1.31874424, 1.4147537 , ..., 1.28607998, 1.41143283,
          1.52148169],
         [1.46005855, 1.49583822, 1.47602932, ..., 1.35379826, 1.48621636,
          1.53421813],
         [1.01332234, 0.83147991, 0.89266225, ..., 1.33598894, 0.85190091,
          0.65474185],
         [1.46880606, 1.35175648, 1.14274885, ..., 1.6005793 , 1.03006614,
          0.68757185]])},
 (5, 750))

In [8]:
# add a leading batch axis and a trailing singleton axis: (C, E) -> (1, C, E, 1)
np.expand_dims(np.expand_dims(test_data['edge_features'], 0), -1).shape, [test_data['mesh']]

((1, 5, 750, 1), [<meshcnn.models.layers.mesh.Mesh at 0x225f257ce48>])

In [9]:
meshcnn_args = {
    'nf0': test_data['edge_features'].shape[0],
    'conv_res': [64, 128, 256, 256],
    'input_res': test_data['edge_features'].shape[1],
    'pool_res': [600, 450, 300, 180],
    'norm': 'group',
    'num_groups': 16,
    'nresblocks': 1,
}

meshcnn_feat_ex = MeshCNNFeatureExtractor(meshcnn_args)
meshcnn_feat_ex.eval()

# (C, E) edge features -> (batch=1, C, E, 1)
meshcnn_in = torch.from_numpy(test_data['edge_features'][None, ..., None]).float()

# NOTE(review): upstream MeshCNN's MeshPool appears to simplify the Mesh object in place, so
# the forward pass gets a deep copy to keep test_data['mesh'] intact and this cell re-runnable.
with torch.no_grad():  # inference only — no autograd graph needed
    meshcnn_out = meshcnn_feat_ex(meshcnn_in, [copy.deepcopy(test_data['mesh'])])

assert meshcnn_out.shape == (1, meshcnn_args['conv_res'][-1])
meshcnn_feat_ex, meshcnn_out.shape

(MeshCNNFeatureExtractor(
   (conv0): MResConv(
     (conv0): MeshConv(
       (conv): Conv2d(5, 64, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv1): MeshConv(
       (conv): Conv2d(64, 64, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
   )
   (norm0): GroupNorm(16, 64, eps=1e-05, affine=True)
   (pool0): MeshPool()
   (conv1): MResConv(
     (conv0): MeshConv(
       (conv): Conv2d(64, 128, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
     (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv1): MeshConv(
       (conv): Conv2d(128, 128, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
   )
   (norm1): GroupNorm(16, 128, eps=1e-05, affine=True)
   (pool1): MeshPool()
   (conv2): MResConv(
     (conv0): MeshConv(
       (conv): Conv2d(128, 256, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
     (bn1): Ba