In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

sys.path.append('..')

In [3]:
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

## DGCNN branch

In [4]:
from dgcnn.model import get_graph_feature


class DGCNNFeatureExtractor(nn.Module):
    """DGCNN-style backbone producing one global feature vector per input sample.

    Four stages of (``get_graph_feature`` -> 1x1 Conv2d -> BatchNorm -> LeakyReLU ->
    max over the last axis) run back to back. Their outputs are concatenated along
    the channel axis, passed through a 1x1 Conv1d block and summarised by global
    max- and average-pooling, which are concatenated into the final vector.

    Attributes:
        args: the option dict given to the constructor (stored as-is)
        k: value forwarded to ``get_graph_feature`` as its ``k`` argument
        feat_dim: length of each output vector, ``args['emb_dims'] * 2``
    """

    def __init__(self, args: dict):
        """
        args must contain:
            'k'
            'emb_dims'
        """
        super().__init__()
        self.args = args
        self.k = args['k']
        emb_dims = args['emb_dims']

        # Norm layers are registered under their own names first; each one is then
        # reused as the middle element of the matching ``convN`` block.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(emb_dims)

        def conv_block(conv, norm):
            # conv -> norm -> LeakyReLU(0.2), shared by all five blocks
            return nn.Sequential(conv, norm, nn.LeakyReLU(negative_slope=0.2))

        # Each Conv2d takes twice the channel count of the tensor handed to
        # get_graph_feature (6 for the 3-channel input, then 64*2, 64*2, 128*2).
        self.conv1 = conv_block(nn.Conv2d(6, 64, kernel_size=1, bias=False), self.bn1)
        self.conv2 = conv_block(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False), self.bn2)
        self.conv3 = conv_block(nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False), self.bn3)
        self.conv4 = conv_block(nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False), self.bn4)
        # 512 = 64 + 64 + 128 + 256, the concatenation of the four stage outputs.
        self.conv5 = conv_block(nn.Conv1d(512, emb_dims, kernel_size=1, bias=False), self.bn5)

        self.feat_dim = emb_dims * 2

    def forward(self, x):
        """Return a ``(batch, feat_dim)`` tensor for the input batch ``x``."""
        n_batch = x.size(0)

        stage_outputs = []
        current = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            current = stage(get_graph_feature(current, k=self.k))
            current = current.max(dim=-1, keepdim=False)[0]
            stage_outputs.append(current)

        embedded = self.conv5(torch.cat(stage_outputs, dim=1))

        pooled_max = F.adaptive_max_pool1d(embedded, 1).view(n_batch, -1)
        pooled_avg = F.adaptive_avg_pool1d(embedded, 1).view(n_batch, -1)
        return torch.cat((pooled_max, pooled_avg), 1)

In [5]:
# Options for the DGCNN branch; with emb_dims=128 the extractor's feat_dim is 256.
dgcnn_args = {
    'k': 5,
    'emb_dims': 128,
}

dgcnn_feat_ex = DGCNNFeatureExtractor(dgcnn_args)
# eval() so BatchNorm uses its running statistics during this check.
dgcnn_feat_ex.eval()
# Display the module repr together with the output shape for a random (8, 3, 1024) batch.
dgcnn_feat_ex, dgcnn_feat_ex(torch.randn((8,3,1024))).shape

(DGCNNFeatureExtractor(
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (bn5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (conv1): Sequential(
     (0): Conv2d(6, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): LeakyReLU(negative_slope=0.2)
   )
   (conv2): Sequential(
     (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): LeakyReLU(negative_slope=0.2)
   )
   (conv3): Sequential(
     (0): Conv2d(128, 128, kernel_siz

## MeshCNN branch

In [6]:
from meshcnn.models.networks import MResConv, get_norm_args, get_norm_layer
from meshcnn.models.layers.mesh_pool import MeshPool
from meshcnn.models.layers.mesh import Mesh


class MeshCNNFeatureExtractor(nn.Module):
    """MeshCNN-style backbone: repeated (MResConv -> norm -> ReLU -> MeshPool) stages
    followed by average pooling over the remaining edges.

    Attributes:
        k: channel counts, ``[nf0] + conv_res``
        res: edge counts, ``[input_res] + pool_res``
        feat_dim: length of each output vector (the last entry of ``k``)
    """

    def __init__(self, args: dict):
        """
        args must contain:
            'nf0': int
                num input channels (5 for the usual MeshCNN initial edge features)
                Corresponds to "opt.input_nc" in original code, with no default (inferred from dataset)

            'conv_res': list of ints
                num out channels (i.e. filters) for each meshconv layer
                Corresponds to "opt.ncf" in original code, with default [16, 32, 32]

            'input_res': int
                num input edges (we take only this many edges from each input mesh)
                Corresponds to "opt.ninput_edges" in original code, with default 750

            'pool_res': list of ints
                num edges to keep after each meshpool layer
                Corresponds to "opt.pool_res" in original code, with default [1140, 780, 580]

            'norm': str, one of ['batch', 'instance', 'group', 'none']
                type of norm layer to use
                Corresponds to "opt.norm" in original code, with default 'batch'

            'num_groups': int
                num of groups for groupnorm
                Corresponds to "opt.num_groups" in original code, with default 16

            'nresblocks': int
                num res blocks in each mresconv
                Corresponds to "opt.resblocks" in original code, with default 0
        """
        super().__init__()
        self.k = [args['nf0']] + args['conv_res']
        self.res = [args['input_res']] + args['pool_res']

        norm_layer = get_norm_layer(norm_type=args['norm'], num_groups=args['num_groups'])
        norm_args = get_norm_args(norm_layer, self.k[1:])

        # One conv/norm/pool triple per consecutive pair of channel counts,
        # registered as conv0/norm0/pool0, conv1/norm1/pool1, ...
        channel_pairs = zip(self.k[:-1], self.k[1:])
        for idx, (ch_in, ch_out) in enumerate(channel_pairs):
            setattr(self, 'conv%d' % idx, MResConv(ch_in, ch_out, args['nresblocks']))
            setattr(self, 'norm%d' % idx, norm_layer(**norm_args[idx]))
            setattr(self, 'pool%d' % idx, MeshPool(self.res[idx + 1]))

        # Average over the edges left after the last pool
        # (nn.MaxPool1d(self.res[-1]) was the alternative considered).
        self.gp = nn.AvgPool1d(self.res[-1])

        self.feat_dim = self.k[-1]

    def forward(self, x, mesh):
        """Run every stage on ``x`` (with ``mesh`` passed to conv and pool) and return
        the pooled features viewed as ``(-1, feat_dim)``."""
        n_stages = len(self.k) - 1
        for idx in range(n_stages):
            conv = getattr(self, 'conv%d' % idx)
            norm = getattr(self, 'norm%d' % idx)
            pool = getattr(self, 'pool%d' % idx)
            x = pool(F.relu(norm(conv(x, mesh))), mesh)

        pooled = self.gp(x)
        return pooled.view(-1, self.k[-1])


In [7]:
'''
meshcnn data loader item format - dict, with keys:
    'mesh': Mesh class instance
    'label': Output class label
    'edge_features': Features extracted using extract_features() of the Mesh object above,
                     but PADDED TO ninput_edges AND NORMALIZED BY MEAN&STD OF DATA
'''

# Hand-built single sample mimicking the item format described above.
# NOTE(review): absolute, machine-specific path; this cell only runs where that file exists.
mesh = Mesh(
    file=r'C:\Academic\GT - MSCS\Sem II - Spring 2022\CS 7643 - DL\Project\src\dgcnn\pytorch\data\shrec_16\armadillo\test\T55.obj',
    opt=None, export_folder=None)
# NOTE(review): unlike the format described above, the features here come straight from
# extract_features(); no padding or mean/std normalisation is applied in this cell.
test_data = {
    'mesh': mesh,
    'label': 0,
    'edge_features': mesh.extract_features(),
}
# Display the sample and the shape of its edge-feature array.
test_data, test_data['edge_features'].shape

({'mesh': <meshcnn.models.layers.mesh.Mesh at 0x299a44358c8>,
  'label': 0,
  'edge_features': array([[1.57387046, 1.57369169, 1.57342915, ..., 1.57413897, 1.57166371,
          1.57090315],
         [1.35346367, 1.31874424, 1.4147537 , ..., 1.28607998, 1.41143283,
          1.52148169],
         [1.46005855, 1.49583822, 1.47602932, ..., 1.35379826, 1.48621636,
          1.53421813],
         [1.01332234, 0.83147991, 0.89266225, ..., 1.33598894, 0.85190091,
          0.65474185],
         [1.46880606, 1.35175648, 1.14274885, ..., 1.6005793 , 1.03006614,
          0.68757185]])},
 (5, 750))

In [8]:
# Shape after adding a leading and a trailing singleton axis, shown next to the
# one-element mesh list that is passed alongside it to the MeshCNN branch.
test_data['edge_features'][None,...,None].shape, [test_data['mesh']]

((1, 5, 750, 1), [<meshcnn.models.layers.mesh.Mesh at 0x299a44358c8>])

In [9]:
# Options for the MeshCNN branch; channel and edge counts are read off the sample built above.
meshcnn_args = {
    'nf0': test_data['edge_features'].shape[0],  # number of input feature channels
    'conv_res': [64, 128, 256, 256],  # output channels of each conv stage
    'input_res': test_data['edge_features'].shape[1],  # number of input edges
    'pool_res': [600, 450, 300, 180],  # edges kept after each pool stage
    'norm': 'group',
    'num_groups': 16,
    'nresblocks': 1,
}

meshcnn_feat_ex = MeshCNNFeatureExtractor(meshcnn_args)
meshcnn_feat_ex.eval()
# Display the module repr and the output shape for the single-sample batch.
# NOTE(review): the same Mesh object is reused by later cells; confirm that running
# it through the pooling layers here does not affect those later calls.
meshcnn_feat_ex, meshcnn_feat_ex(
    torch.from_numpy(test_data['edge_features'][None,...,None]).float().to('cpu'),
    [test_data['mesh']]).shape

(MeshCNNFeatureExtractor(
   (conv0): MResConv(
     (conv0): MeshConv(
       (conv): Conv2d(5, 64, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv1): MeshConv(
       (conv): Conv2d(64, 64, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
   )
   (norm0): GroupNorm(16, 64, eps=1e-05, affine=True)
   (pool0): MeshPool()
   (conv1): MResConv(
     (conv0): MeshConv(
       (conv): Conv2d(64, 128, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
     (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv1): MeshConv(
       (conv): Conv2d(128, 128, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
   )
   (norm1): GroupNorm(16, 128, eps=1e-05, affine=True)
   (pool1): MeshPool()
   (conv2): MResConv(
     (conv0): MeshConv(
       (conv): Conv2d(128, 256, kernel_size=(1, 5), stride=(1, 1), bias=False)
     )
     (bn1): Ba

## DGCNN branch on MeshCNN style data

In [10]:
# Inspect the vertex array of the sample mesh and the shape it takes once transposed
# and given a leading batch axis (the layout fed to the DGCNN branch below).
type(mesh.vs), mesh.vs.dtype, mesh.vs.shape, torch.from_numpy(mesh.vs.T[None,...]).shape

(numpy.ndarray, dtype('float64'), (252, 3), torch.Size([1, 3, 252]))

In [11]:
# Run the DGCNN branch on the mesh vertices; .float() converts the float64 array to float32.
dgcnn_feat_ex(
    torch.from_numpy(mesh.vs.T[None,...]).float()
).shape

torch.Size([1, 256])

## Combined network (pre-classification)

In [12]:
class CombinedFeatureExtractor(nn.Module):
    """Runs the DGCNN and MeshCNN branches side by side and concatenates their features.

    Attributes:
        dgcnn_branch: ``DGCNNFeatureExtractor`` built from ``dgcnn_args``
        meshcnn_branch: ``MeshCNNFeatureExtractor`` built from ``meshcnn_args``
        feat_dim: sum of the two branches' ``feat_dim``
    """

    def __init__(self, dgcnn_args: dict, meshcnn_args: dict):
        super().__init__()

        self.dgcnn_branch = DGCNNFeatureExtractor(dgcnn_args)
        self.meshcnn_branch = MeshCNNFeatureExtractor(meshcnn_args)

        branches = (self.dgcnn_branch, self.meshcnn_branch)
        self.feat_dim = sum(branch.feat_dim for branch in branches)

    def forward(self, vertex_input_batch, edge_input_batch, mesh_batch):
        """Return the DGCNN features followed by the MeshCNN features, joined on the last axis."""
        branch_feats = [
            self.dgcnn_branch(vertex_input_batch),
            self.meshcnn_branch(edge_input_batch, mesh_batch),
        ]
        return torch.cat(branch_feats, dim=-1)

In [13]:
# Single-sample inputs for both branches, built from the hand-made sample above.
vertex_input_batch = torch.from_numpy(mesh.vs.T[None,...]).float()
edge_input_batch = torch.from_numpy(test_data['edge_features'][None,...,None]).float()
mesh_batch = [test_data['mesh']]

# NOTE: the extractor is left in its default (training) mode for this shape check.
combined_ex = CombinedFeatureExtractor(dgcnn_args=dgcnn_args, meshcnn_args=meshcnn_args)
combined_ex(vertex_input_batch, edge_input_batch, mesh_batch).shape

torch.Size([1, 512])

## Full network

In [14]:
class CombinedMeshClassifier(nn.Module):
    """Combined feature extractor followed by a Linear -> ReLU -> Linear head.

    ``classifier_args`` must contain:
        'out_block_hidden_dim': width of the hidden layer of the head
        'out_num_classes': number of values the head outputs per sample
    """

    def __init__(self, classifier_args: dict, dgcnn_args: dict, meshcnn_args: dict):
        super().__init__()

        self.feat_ex = CombinedFeatureExtractor(dgcnn_args=dgcnn_args, meshcnn_args=meshcnn_args)

        hidden_dim = classifier_args['out_block_hidden_dim']
        num_classes = classifier_args['out_num_classes']
        head_layers = [
            nn.Linear(in_features=self.feat_ex.feat_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=num_classes),
        ]
        self.output_block = nn.Sequential(*head_layers)

    def forward(self, vertex_input_batch, edge_input_batch, mesh_batch):
        """Extract the combined features and map them through the output head."""
        return self.output_block(
            self.feat_ex(vertex_input_batch, edge_input_batch, mesh_batch)
        )

In [15]:
# Head configuration: one hidden layer of width 1024 and 30 outputs per sample.
classifier_args = {
    'out_block_hidden_dim': 1024,
    'out_num_classes': 30
}

classifier = CombinedMeshClassifier(
    classifier_args=classifier_args,
    dgcnn_args=dgcnn_args,
    meshcnn_args=meshcnn_args,
)

# Shape check on the single-sample inputs built earlier (model left in training mode).
classifier(vertex_input_batch, edge_input_batch, mesh_batch).shape

torch.Size([1, 30])

## Data loading

In [16]:
import pickle
from torch.utils.data import Dataset

from meshcnn.util.util import is_mesh_file, pad

# Root folder of the SHREC16 data (one sub-folder per class). Set the SHREC16_ROOT
# environment variable to point the notebook at a different location; the fallback
# is the original author's local path.
data_root_dir = os.environ.get(
    'SHREC16_ROOT',
    r'C:\Academic\GT - MSCS\Sem II - Spring 2022\CS 7643 - DL\Project\src\dgcnn\pytorch\data\shrec_16',
)

In [35]:
class SHREC16(Dataset):
    """Mesh dataset that yields, per item, the inputs of both network branches.

    Mesh files are discovered under ``<dataroot>/<class folder>/...`` and kept when
    the directory path below the class folder contains the partition name exactly
    once (e.g. ``armadillo/test/T55.obj`` for partition ``'test'``).

    Each item is a dict with keys:
        'mesh': the loaded ``Mesh`` instance
        'label': integer class index (position of the class folder in sorted order)
        'pointcloud': ``mesh.vs[:num_points].T``
        'edge_features': ``mesh.extract_features()`` padded with ``pad`` to
                         ``ninput_edges`` and normalised with the cached mean / std
    """

    def __init__(self, partition, device, opt: dict):
        '''
        opt structure:
            'ninput_edges': int, num edges to use for meshcnn (will pad if higher than actual)
            'num_points': int, num verts to use for dgcnn (has to be at most the actual num verts)
            'dataroot': str
        '''
        super(SHREC16, self).__init__()
        self.partition = partition
        self.device = device

        self.ninput_edges = opt['ninput_edges']
        self.num_points = opt['num_points']
        self.root = opt['dataroot']
        self.dir = os.path.join(self.root)
        self.classes, self.class_to_idx = self.find_classes(self.dir)
        self.paths = self.make_dataset_by_class(self.dir, self.class_to_idx, partition)
        self.nclasses = len(self.classes)
        self.size = len(self.paths)

        # Identity normalisation while get_mean_std() iterates over the raw items.
        self.mean = 0
        self.std = 1
        self.get_mean_std() # init self.mean, self.std, self.ninput_channels

    def __getitem__(self, index):
        """Load the mesh at ``index`` from disk and build its item dict."""
        path, label = self.paths[index][0], self.paths[index][1]
        mesh = Mesh(file=path, opt=None, hold_history=False, export_folder=None)
        # First num_points vertices, transposed. A mesh with fewer vertices yields a
        # shorter array here, which cannot be stacked with others in a batch.
        pointcloud = mesh.vs[:self.num_points].T
        meta = {'mesh': mesh, 'label': label, 'pointcloud': pointcloud}

        edge_features = mesh.extract_features()
        edge_features = pad(edge_features, self.ninput_edges)
        meta['edge_features'] = (edge_features - self.mean) / self.std
        return meta

    def __len__(self):
        return self.size

    def get_mean_std(self):
        """Load per-channel mean / std of the edge features, computing them if needed.

        The statistics are cached in ``<dataroot>/mean_std_cache.pkl``. When that
        file is missing they are computed from THIS dataset instance (whatever its
        partition) as the average of the per-item means and stds, then written to
        the cache. The cache is shared by every partition under the same root, so
        construct the training dataset first to get training-set statistics.

        Sets ``self.mean``, ``self.std`` and ``self.ninput_channels``.

        :raises ValueError: if the cache is missing and this partition has no meshes
        """

        mean_std_cache = os.path.join(self.root, 'mean_std_cache.pkl')
        if not os.path.isfile(mean_std_cache):
            if self.size == 0:
                raise ValueError(
                    'cannot compute mean / std: no meshes found for partition {!r} under {!r}'.format(
                        self.partition, self.root))
            print('computing mean std from {} data...'.format(self.partition))
            mean, std = np.array(0), np.array(0)
            # Index explicitly instead of iterating over self: the implicit protocol
            # stops on IndexError, so one raised while loading a mesh would silently
            # truncate the statistics.
            for i in range(self.size):
                data = self[i]
                if i % 5 == 0:
                    print('{} of {}'.format(i, self.size))
                features = data['edge_features']
                mean = mean + features.mean(axis=1)
                std = std + features.std(axis=1)
            mean = mean / self.size
            std = std / self.size
            transform_dict = {'mean': mean[:, np.newaxis], 'std': std[:, np.newaxis],
                              'ninput_channels': len(mean)}
            with open(mean_std_cache, 'wb') as f:
                pickle.dump(transform_dict, f)
            print('saved: ', mean_std_cache)

        # open mean / std from file
        # NOTE(security): pickle.load executes arbitrary code from the file; only
        # point 'dataroot' at directories whose cache file you trust.
        with open(mean_std_cache, 'rb') as f:
            transform_dict = pickle.load(f)
            print('loaded mean / std from cache')
            self.mean = transform_dict['mean']
            self.std = transform_dict['std']
            self.ninput_channels = transform_dict['ninput_channels']

    # this is when the folders are organized by class...
    @staticmethod
    def find_classes(dir):
        """Return the sorted sub-folder names of ``dir`` and a name -> index mapping."""
        classes = sorted(d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)))
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    @staticmethod
    def make_dataset_by_class(dir, class_to_idx, partition):
        """Collect ``(path, class index)`` pairs for every mesh file of ``partition``.

        A file is kept when ``is_mesh_file`` accepts its name and the directory path
        relative to its class folder contains ``partition`` exactly once.
        """
        meshes = []
        dir = os.path.expanduser(dir)
        for target in sorted(os.listdir(dir)):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue
            for root, _, fnames in sorted(os.walk(d)):
                # Only look at the part of the path below the class folder, so that a
                # data root or class name containing e.g. "test" or "train" does not
                # distort the match.
                if os.path.relpath(root, d).count(partition) != 1:
                    continue
                for fname in sorted(fnames):
                    if is_mesh_file(fname):
                        path = os.path.join(root, fname)
                        item = (path, class_to_idx[target])
                        meshes.append(item)
        return meshes

In [43]:
import functools

def collate_fn(batch, device, is_train):
    """Merge a list of dataset items into one mini-batch dict.

    Every key of the first item is stacked across the batch with ``np.array``.
    'edge_features' and 'pointcloud' then become float tensors on ``device`` whose
    ``requires_grad`` flag equals ``is_train``; 'label' becomes a long tensor on
    ``device``. 'mesh' stays as the NumPy array built from the per-item values.
    """
    field_names = batch[0].keys()
    meta = {name: np.array([sample[name] for sample in batch]) for name in field_names}

    edge_feats = torch.from_numpy(meta['edge_features']).float().to(device)
    points = torch.from_numpy(meta['pointcloud']).float().to(device)
    labels = torch.from_numpy(meta['label']).long().to(device)

    meta['edge_features'] = edge_feats.requires_grad_(is_train)
    meta['pointcloud'] = points.requires_grad_(is_train)
    meta['label'] = labels
    # meta['mesh'] needs no conversion; it already holds the batch's meshes
    return meta


class DataLoader:
    """Wraps ``torch.utils.data.DataLoader`` around a ``SHREC16`` dataset.

    Batches are built with ``collate_fn`` (tensors moved to the selected device,
    gradients enabled on the inputs only for the 'train' partition) and iteration
    stops early once ``max_dataset_size`` samples have been served.
    """

    def __init__(self, partition, opt: dict):
        '''
        opt structure:
            'gpu_ids': list of ints, or None / missing (for cpu); the first id is used
            'batch_size': int (default: 16)
            'max_dataset_size': int (default: inf)
            'shuffle': bool. Whether to shuffle or not
            'num_threads': int, forwarded as num_workers to the torch DataLoader

            'dataset_opt': dict (i.e. nested options), with structure as mentioned in class SHREC16
        '''
        gpu_ids = opt.get('gpu_ids')
        device = torch.device('cuda:{}'.format(gpu_ids[0])) if gpu_ids else torch.device('cpu')
        self.dataset = SHREC16(partition, device, opt['dataset_opt'])

        # Honour the defaults documented above instead of raising KeyError.
        self.batch_size = opt.get('batch_size', 16)
        self.max_dataset_size = opt.get('max_dataset_size', float('inf'))
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=opt['shuffle'],
            num_workers=opt['num_threads'],
            collate_fn=functools.partial(
                collate_fn,
                device=device, is_train=(partition=='train')
            )
        )

    def __len__(self):
        """Number of SAMPLES (not batches) served: dataset size capped by max_dataset_size."""
        # int(): max_dataset_size may be a float (inf by default) and len() requires
        # an integer result.
        return int(min(len(self.dataset), self.max_dataset_size))

    def __iter__(self):
        """Yield batches until ``max_dataset_size`` samples have been covered."""
        for i, data in enumerate(self.dataloader):
            if i * self.batch_size >= self.max_dataset_size:
                break
            yield data

In [44]:
# Loader configuration shared by the train and test loaders (CPU, batches of 16).
dataloader_args = {
    'gpu_ids': None,
    'batch_size': 16,
    'max_dataset_size': np.inf,
    'shuffle': True,
    'num_threads': 0,
    'dataset_opt': {
        'ninput_edges': 750,
        'num_points': 250,
        'dataroot': data_root_dir,
    },
}

# The train loader is built first, so a missing mean/std cache is computed from the
# training partition.
train_dataloader = DataLoader('train', dataloader_args)
test_dataloader = DataLoader('test', dataloader_args)
# len() of these loaders counts samples, not batches.
len(train_dataloader), len(test_dataloader)

loaded mean / std from cache
loaded mean / std from cache


(480, 120)

In [49]:
# Shape check of the full classifier on real batches from both loaders.
# - zip() stops with the shorter (test) loader, so only part of the training set is visited.
# - The loop variables are named *_batch so they do not overwrite the single-sample
#   `test_data` dict defined earlier, which earlier cells still need when re-run.
# - no_grad(): only output shapes are inspected, so no autograd graph is needed.
with torch.no_grad():
    for train_batch, test_batch in zip(train_dataloader, test_dataloader):
        print(train_batch['pointcloud'].shape, test_batch['edge_features'].shape)

        print(classifier(train_batch['pointcloud'], train_batch['edge_features'], train_batch['mesh']).shape)
        print(classifier(test_batch['pointcloud'], test_batch['edge_features'], test_batch['mesh']).shape)

torch.Size([16, 3, 250]) torch.Size([16, 5, 750])
torch.Size([16, 30])
torch.Size([16, 30])
torch.Size([16, 3, 250]) torch.Size([16, 5, 750])
torch.Size([16, 30])
torch.Size([16, 30])
torch.Size([16, 3, 250]) torch.Size([16, 5, 750])
torch.Size([16, 30])
torch.Size([16, 30])
torch.Size([16, 3, 250]) torch.Size([16, 5, 750])
torch.Size([16, 30])
torch.Size([16, 30])
torch.Size([16, 3, 250]) torch.Size([16, 5, 750])
torch.Size([16, 30])
torch.Size([16, 30])
torch.Size([16, 3, 250]) torch.Size([16, 5, 750])
torch.Size([16, 30])
torch.Size([16, 30])
torch.Size([16, 3, 250]) torch.Size([16, 5, 750])
torch.Size([16, 30])
torch.Size([16, 30])
torch.Size([16, 3, 250]) torch.Size([8, 5, 750])
torch.Size([16, 30])
torch.Size([8, 30])
