## (i)  Performing the necessary imports before writing any code.

In [35]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import argparse
import functools
import sys
import time
from heapq import heapify, heappop
from os.path import join
from threading import Thread

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import init
from torchviz import make_dot

from data import DataLoader
from models import create_model
from models import networks
from models.layers.mesh_conv import MeshConv
from models.layers.mesh_pool import MeshPool
from models.layers.mesh_union import MeshUnion
from util import util
from util.util import seg_accuracy, print_network
from util.writer import Writer

CUDA_LAUNCH_BLOCKING="1"


### The MeshPooling Operations: 


(A) Running the MeshPooling operation to obtain the pooled version of the mesh

(B) Pooling is implemented with a heap used as a priority queue: edges are popped in order of priority (smallest squared feature magnitude first) and collapsed until the target number of edges is reached.

In [33]:
class MeshPool(nn.Module):
    """Pools edge features by collapsing mesh edges until ``target`` edges remain.

    For every mesh in the batch, edges are popped from a min-heap keyed by the
    squared magnitude of their features (see ``__build_queue``). Each successful
    collapse removes the popped edge plus one edge on each of its two sides and
    records the merge in a ``MeshUnion`` so the surviving edge features can be
    rebuilt afterwards.
    """

    def __init__(self, target, multi_thread=False):
        """
        :param target: number of edges every mesh should have after pooling
        :param multi_thread: when True, each mesh is pooled in its own Thread
        """
        super(MeshPool, self).__init__()
        self.__out_target = target
        self.__multi_thread = multi_thread
        self.__fe = None
        self.__updated_fe = None
        self.__meshes = None
        self.__merge_edges = [-1, -1]

    def __call__(self, fe, meshes):
        """Invoke ``forward`` directly with the edge features and meshes."""
        return self.forward(fe, meshes)

    def forward(self, fe, meshes):
        """Pool every mesh of the batch and return the stacked features.

        :param fe: edge features, indexed as ``fe[mesh_index, :, :edges_count]``
        :param meshes: sequence of mesh objects, one per batch entry
        :return: tensor viewed as ``(len(meshes), -1, target)``
        """
        self.__updated_fe = [[] for _ in range(len(meshes))]
        pool_threads = []
        self.__fe = fe
        self.__meshes = meshes
        # iterate over batch
        for mesh_index in range(len(meshes)):
            if self.__multi_thread:
                pool_threads.append(Thread(target=self.__pool_main, args=(mesh_index,)))
                pool_threads[-1].start()
            else:
                print('Mesh Index is', mesh_index)
                self.__pool_main(mesh_index)
        if self.__multi_thread:
            # wait for every worker before the results are concatenated
            for mesh_index in range(len(meshes)):
                pool_threads[mesh_index].join()
        out_features = torch.cat(self.__updated_fe).view(len(meshes), -1, self.__out_target)
        return out_features

    def __pool_main(self, mesh_index):
        """Collapse edges of one mesh until it reaches the target edge count."""
        mesh = self.__meshes[mesh_index]
        queue = self.__build_queue(self.__fe[mesh_index, :, :mesh.edges_count], mesh.edges_count)
        # Builtin bool: the np.bool alias was removed from NumPy (1.24+).
        mask = np.ones(mesh.edges_count, dtype=bool)
        edge_groups = MeshUnion(mesh.edges_count, self.__fe.device)
        while mesh.edges_count > self.__out_target:
            value, edge_id = heappop(queue)
            # ids are stored as floats inside the heap entries
            edge_id = int(edge_id)
            # skip edges that were already removed by an earlier collapse
            if mask[edge_id]:
                self.__pool_edge(mesh, edge_id, mask, edge_groups)
        mesh.clean(mask, edge_groups)
        fe = edge_groups.rebuild_features(self.__fe[mesh_index], mask, self.__out_target)
        self.__updated_fe[mesh_index] = fe

    def __pool_edge(self, mesh, edge_id, mask, edge_groups):
        """Try to collapse ``edge_id``; return True when the collapse happened.

        The collapse is refused for boundary edges, when cleaning either side
        fails, or when the one-ring check does not pass.
        """
        if self.has_boundaries(mesh, edge_id):
            return False
        elif self.__clean_side(mesh, edge_id, mask, edge_groups, 0)\
            and self.__clean_side(mesh, edge_id, mask, edge_groups, 2) \
            and self.__is_one_ring_valid(mesh, edge_id):
            self.__merge_edges[0] = self.__pool_side(mesh, edge_id, mask, edge_groups, 0)
            self.__merge_edges[1] = self.__pool_side(mesh, edge_id, mask, edge_groups, 2)
            mesh.merge_vertices(edge_id)
            mask[edge_id] = False
            MeshPool.__remove_group(mesh, edge_groups, edge_id)
            mesh.edges_count -= 1
            return True
        else:
            return False

    def __clean_side(self, mesh, edge_id, mask, edge_groups, side):
        """Remove invalid edge triplets on one side of ``edge_id``.

        Returns False when the target edge count is reached or the edge turns
        into a boundary edge while cleaning; True otherwise.
        """
        if mesh.edges_count <= self.__out_target:
            return False
        invalid_edges = MeshPool.__get_invalids(mesh, edge_id, edge_groups, side)
        while len(invalid_edges) != 0 and mesh.edges_count > self.__out_target:
            self.__remove_triplete(mesh, mask, edge_groups, invalid_edges)
            if mesh.edges_count <= self.__out_target:
                return False
            if self.has_boundaries(mesh, edge_id):
                return False
            invalid_edges = self.__get_invalids(mesh, edge_id, edge_groups, side)
        return True

    @staticmethod
    def has_boundaries(mesh, edge_id):
        """Return True if ``edge_id`` or any of its gemm neighbours has a -1 entry."""
        for edge in mesh.gemm_edges[edge_id]:
            if edge == -1 or -1 in mesh.gemm_edges[edge]:
                return True
        return False


    @staticmethod
    def __is_one_ring_valid(mesh, edge_id):
        """Return True when the edge's two endpoints share exactly two other vertices."""
        v_a = set(mesh.edges[mesh.ve[mesh.edges[edge_id, 0]]].reshape(-1))
        v_b = set(mesh.edges[mesh.ve[mesh.edges[edge_id, 1]]].reshape(-1))
        # '-' binds tighter than '&'; the result equals (v_a & v_b) minus the edge's own vertices
        shared = v_a & v_b - set(mesh.edges[edge_id])
        return len(shared) == 2

    def __pool_side(self, mesh, edge_id, mask, edge_groups, side):
        """Merge the two neighbour edges on ``side`` into one and return the kept edge key."""
        info = MeshPool.__get_face_info(mesh, edge_id, side)
        key_a, key_b, side_a, side_b, _, other_side_b, _, other_keys_b = info
        self.__redirect_edges(mesh, key_a, side_a - side_a % 2, other_keys_b[0], mesh.sides[key_b, other_side_b])
        self.__redirect_edges(mesh, key_a, side_a - side_a % 2 + 1, other_keys_b[1], mesh.sides[key_b, other_side_b + 1])
        MeshPool.__union_groups(mesh, edge_groups, key_b, key_a)
        MeshPool.__union_groups(mesh, edge_groups, edge_id, key_a)
        mask[key_b] = False
        MeshPool.__remove_group(mesh, edge_groups, key_b)
        mesh.remove_edge(key_b)
        mesh.edges_count -= 1
        return key_a

    @staticmethod
    def __get_invalids(mesh, edge_id, edge_groups, side):
        """Detect a triplet of edges to remove on ``side`` of ``edge_id``.

        Returns an empty list when the two neighbours share no further edge;
        otherwise rewires the neighbourhood around the triplet, merges the
        feature groups and returns ``[key_a, key_b, middle_edge]``.
        """
        info = MeshPool.__get_face_info(mesh, edge_id, side)
        key_a, key_b, side_a, side_b, other_side_a, other_side_b, other_keys_a, other_keys_b = info
        shared_items = MeshPool.__get_shared_items(other_keys_a, other_keys_b)
        if len(shared_items) == 0:
            return []
        else:
            assert (len(shared_items) == 2)
            middle_edge = other_keys_a[shared_items[0]]
            update_key_a = other_keys_a[1 - shared_items[0]]
            update_key_b = other_keys_b[1 - shared_items[1]]
            update_side_a = mesh.sides[key_a, other_side_a + 1 - shared_items[0]]
            update_side_b = mesh.sides[key_b, other_side_b + 1 - shared_items[1]]
            MeshPool.__redirect_edges(mesh, edge_id, side, update_key_a, update_side_a)
            MeshPool.__redirect_edges(mesh, edge_id, side + 1, update_key_b, update_side_b)
            MeshPool.__redirect_edges(mesh, update_key_a, MeshPool.__get_other_side(update_side_a), update_key_b, MeshPool.__get_other_side(update_side_b))
            MeshPool.__union_groups(mesh, edge_groups, key_a, edge_id)
            MeshPool.__union_groups(mesh, edge_groups, key_b, edge_id)
            MeshPool.__union_groups(mesh, edge_groups, key_a, update_key_a)
            MeshPool.__union_groups(mesh, edge_groups, middle_edge, update_key_a)
            MeshPool.__union_groups(mesh, edge_groups, key_b, update_key_b)
            MeshPool.__union_groups(mesh, edge_groups, middle_edge, update_key_b)
            return [key_a, key_b, middle_edge]

    @staticmethod
    def __redirect_edges(mesh, edge_a_key, side_a, edge_b_key, side_b):
        """Make edges a and b mutual neighbours at the given side slots."""
        mesh.gemm_edges[edge_a_key, side_a] = edge_b_key
        mesh.gemm_edges[edge_b_key, side_b] = edge_a_key
        mesh.sides[edge_a_key, side_a] = side_b
        mesh.sides[edge_b_key, side_b] = side_a

    @staticmethod
    def __get_shared_items(list_a, list_b):
        """Return a flat list of index pairs ``[i, j, ...]`` where ``list_a[i] == list_b[j]``."""
        shared_items = []
        for i in range(len(list_a)):
            for j in range(len(list_b)):
                if list_a[i] == list_b[j]:
                    shared_items.extend([i, j])
        return shared_items

    @staticmethod
    def __get_other_side(side):
        """Return the partner slot of ``side`` within its pair (0<->1, 2<->3)."""
        return side + 1 - 2 * (side % 2)

    @staticmethod
    def __get_face_info(mesh, edge_id, side):
        """Gather the neighbour keys, their side slots and their far-side neighbours."""
        key_a = mesh.gemm_edges[edge_id, side]
        key_b = mesh.gemm_edges[edge_id, side + 1]
        side_a = mesh.sides[edge_id, side]
        side_b = mesh.sides[edge_id, side + 1]
        # first slot of the opposite pair: {0,1} -> 2, {2,3} -> 0
        other_side_a = (side_a - (side_a % 2) + 2) % 4
        other_side_b = (side_b - (side_b % 2) + 2) % 4
        other_keys_a = [mesh.gemm_edges[key_a, other_side_a], mesh.gemm_edges[key_a, other_side_a + 1]]
        other_keys_b = [mesh.gemm_edges[key_b, other_side_b], mesh.gemm_edges[key_b, other_side_b + 1]]
        return key_a, key_b, side_a, side_b, other_side_a, other_side_b, other_keys_a, other_keys_b

    @staticmethod
    def __remove_triplete(mesh, mask, edge_groups, invalid_edges):
        """Remove three edges and the single vertex they all share."""
        vertex = set(mesh.edges[invalid_edges[0]])
        for edge_key in invalid_edges:
            vertex &= set(mesh.edges[edge_key])
            mask[edge_key] = False
            MeshPool.__remove_group(mesh, edge_groups, edge_key)
        mesh.edges_count -= 3
        vertex = list(vertex)
        assert(len(vertex) == 1)
        mesh.remove_vertex(vertex[0])

    def __build_queue(self, features, edges_count):
        """Build a min-heap of ``[squared_magnitude, edge_id]`` entries.

        :param features: per-edge features; magnitudes are summed over dim 0
        :param edges_count: number of edge ids to enumerate
        :return: heapified list of two-element lists
        """
        # delete edges with smallest norm
        squared_magnitude = torch.sum(features * features, 0)
        if squared_magnitude.shape[-1] != 1:
            squared_magnitude = squared_magnitude.unsqueeze(-1)
        edge_ids = torch.arange(edges_count, device=squared_magnitude.device, dtype=torch.float32).unsqueeze(-1)
        heap = torch.cat((squared_magnitude, edge_ids), dim=-1).tolist()
        heapify(heap)
        return heap

    @staticmethod
    def __union_groups(mesh, edge_groups, source, target):
        """Record the merge of ``source`` into ``target`` in both bookkeeping structures."""
        edge_groups.union(source, target)
        mesh.union_groups(source, target)

    @staticmethod
    def __remove_group(mesh, edge_groups, index):
        """Drop the group of ``index`` from both bookkeeping structures."""
        edge_groups.remove_group(index)
        mesh.remove_group(index)


## (a) This function is called right after Mesh ConvNet


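(i) `setattr` and `getattr` are used to register and later retrieve the per-layer modules (`conv{i}`, `norm{i}`, `pool{i}`) by name.

(ii) In `__init__`, `setattr` stores each layer under its generated name; in `forward`, `getattr` looks the layer up again by that same name and applies it.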

(iii) MeshPool changes the structure of the mesh completely. 

In [13]:
class MeshConvNet(nn.Module):
    """Mesh classifier: conv/norm/pool stages, global average pooling, two FC layers.
    """
    def __init__(self, norm_layer, nf0, conv_res, nclasses, input_res, pool_res, fc_n,
                 nresblocks=3):
        # Echo the configuration the network is being built with.
        summary = (
            ('Norm layer', norm_layer),
            ('Nf0', nf0),
            ('Number of Convolutional Filter', conv_res),
            ('Number of Classes', nclasses),
            ('Number of Input Edges', input_res),
            ('Edge Resolution Across Layers', pool_res),
            ('Number of nodes is penultimate fc layer', fc_n),
            ('Number of Residual Blocks:: Skip Connections', nresblocks),
        )
        for label, value in summary:
            print(label, value)
        print('In Channels Start :', [nf0] + conv_res)
        super().__init__()

        # Channel count per stage (input first) and edge resolution per stage.
        self.k = [nf0] + conv_res
        self.res = [input_res] + pool_res
        norm_args = get_norm_args(norm_layer, self.k[1:])

        # One residual conv, one norm and one pool module per stage,
        # registered under the names conv<i>, norm<i>, pool<i>.
        for stage in range(len(self.k) - 1):
            channels_in = self.k[stage]
            channels_out = self.k[stage + 1]
            setattr(self, 'conv{}'.format(stage), MResConv(channels_in, channels_out, nresblocks))
            setattr(self, 'norm{}'.format(stage), norm_layer(**norm_args[stage]))
            setattr(self, 'pool{}'.format(stage), MeshPool(self.res[stage + 1]))

        # Average over the edges that remain after the last pooling stage.
        self.gp = torch.nn.AvgPool1d(self.res[-1])
        self.fc1 = nn.Linear(self.k[-1], fc_n)
        self.fc2 = nn.Linear(fc_n, nclasses)

    def forward(self, x, mesh):
        # Convolutional stages: conv -> norm -> ReLU -> pool.
        for stage, _ in enumerate(self.k[:-1]):
            conv = getattr(self, 'conv{}'.format(stage))
            x = conv(x, mesh)
            norm = getattr(self, 'norm{}'.format(stage))
            x = F.relu(norm(x))
            pool = getattr(self, 'pool{}'.format(stage))
            x = pool(x, mesh)

        # Classification head.
        pooled = self.gp(x)
        flat = pooled.view(-1, self.k[-1])
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

### (i) Residual Convolutional Layers, the first possible layer

In [14]:
# The Mesh-Residual convolution class. 
class MResConv(nn.Module):
    """Residual mesh convolution: an entry conv followed by ``skips`` BN+conv steps,
    whose result is added back onto the entry conv output.
    """
    def __init__(self, in_channels, out_channels, skips=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skips = skips
        self.conv0 = MeshConv(self.in_channels, self.out_channels, bias=False)
        # Modules are registered as bn1/conv1, bn2/conv2, ...
        for step in range(1, self.skips + 1):
            setattr(self, 'bn{}'.format(step), nn.BatchNorm2d(self.out_channels))
            setattr(self, 'conv{}'.format(step),
                    MeshConv(self.out_channels, self.out_channels, bias=False))

    def forward(self, x, mesh):
        x = self.conv0(x, mesh)
        shortcut = x
        for step in range(1, self.skips + 1):
            batch_norm = getattr(self, 'bn{}'.format(step))
            conv = getattr(self, 'conv{}'.format(step))
            x = conv(batch_norm(F.relu(x)), mesh)
        # in-place residual addition, then the final activation
        x += shortcut
        return F.relu(x)

### (ii) The norm_layer for normalization 

In [15]:
def get_norm_layer(norm_type='instance', num_groups=1):
    """Return a normalization-layer factory for the given ``norm_type``.

    'batch', 'instance' and 'group' yield ``functools.partial`` objects around
    the matching ``torch.nn`` class; 'none' yields ``NoNorm``. Any other value
    raises NotImplementedError.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == 'group':
        return functools.partial(nn.GroupNorm, affine=True, num_groups=num_groups)
    if norm_type == 'none':
        return NoNorm
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

In [16]:
def get_norm_args(norm_layer, nfeats_list):
    """Build the constructor kwargs of ``norm_layer`` for each feature count.

    :param norm_layer: a normalization class, or a ``functools.partial`` around
        one (as returned by ``get_norm_layer``)
    :param nfeats_list: iterable with the number of features of each layer
    :return: list with one kwargs dict per entry of ``nfeats_list``
    :raises NotImplementedError: for an unrecognised normalization layer
    """
    if getattr(norm_layer, '__name__', None) == 'NoNorm':
        return [{'fake': True} for _ in nfeats_list]

    # functools.partial exposes the wrapped class as .func; accept a bare class too.
    layer_cls = getattr(norm_layer, 'func', norm_layer)
    layer_name = getattr(layer_cls, '__name__', repr(layer_cls))

    if layer_name == 'GroupNorm':
        return [{'num_channels': f} for f in nfeats_list]
    # The torch classes are named BatchNorm1d/2d/3d and InstanceNorm1d/2d/3d,
    # so compare by prefix; both families take num_features.
    if layer_name.startswith(('BatchNorm', 'InstanceNorm')):
        return [{'num_features': f} for f in nfeats_list]
    raise NotImplementedError('normalization layer [%s] is not found' % layer_name)

### (iii) The initialization of the network, with random parameters.

In [17]:
def init_net(net, init_type, init_gain, gpu_ids):
    """Move ``net`` to the GPU(s) if any are requested, then initialise its weights.

    With a non-empty ``gpu_ids`` the network is wrapped in DataParallel.
    Weight initialisation is skipped when ``init_type`` is 'none'.
    """
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert(torch.cuda.is_available())
        net.cuda(gpu_ids[0])
        net = torch.nn.DataParallel(net.cuda(), gpu_ids)
    if init_type == 'none':
        return net
    init_weights(net, init_type, init_gain)
    return net

### (iv) The initialization algorithm of our network: weight initialization

In [18]:
def init_weights(net, init_type, init_gain):
    """Initialise the weights of every sub-module of ``net`` in place.

    Modules with a ``weight`` whose class name contains 'Conv' or 'Linear' are
    initialised according to ``init_type`` ('normal', 'xavier', 'kaiming' or
    'orthogonal'; anything else raises NotImplementedError). BatchNorm2d
    modules get weights drawn around 1.0 and zero bias.
    """
    def _initialise(module):
        name = module.__class__.__name__
        is_conv_or_linear = 'Conv' in name or 'Linear' in name
        if hasattr(module, 'weight') and is_conv_or_linear:
            if init_type == 'normal':
                init.normal_(module.weight.data, 0.0, init_gain)
                return
            if init_type == 'xavier':
                init.xavier_normal_(module.weight.data, gain=init_gain)
                return
            if init_type == 'kaiming':
                init.kaiming_normal_(module.weight.data, a=0, mode='fan_in')
                return
            if init_type == 'orthogonal':
                init.orthogonal_(module.weight.data, gain=init_gain)
                return
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        if 'BatchNorm2d' in name:
            init.normal_(module.weight.data, 1.0, init_gain)
            init.constant_(module.bias.data, 0.0)

    net.apply(_initialise)

### (v)  Which Classifier to Use, Mesh-Conv-Net or Mesh-UNET for segmentation. 

In [19]:
def define_classifier(input_nc, ncf, ninput_edges, nclasses, opt, gpu_ids, arch, init_type, init_gain):
    """Build the network selected by ``arch`` and pass it through ``init_net``.

    'mconvnet' builds a MeshConvNet, 'meshunet' builds a MeshEncoderDecoder;
    any other value raises NotImplementedError.
    """
    norm_layer = get_norm_layer(norm_type=opt.norm, num_groups=opt.num_groups)
    if arch == 'mconvnet':
        # convolutional mesh classifier
        network = MeshConvNet(norm_layer, input_nc, ncf, nclasses, ninput_edges,
                              opt.pool_res, opt.fc_n, opt.resblocks)
    elif arch == 'meshunet':
        # encoder channels, mirrored decoder channels ending in the class count,
        # and the edge resolution of every level
        encoder_channels = [input_nc] + ncf
        decoder_channels = ncf[::-1] + [nclasses]
        resolutions = [ninput_edges] + opt.pool_res
        network = MeshEncoderDecoder(resolutions, encoder_channels, decoder_channels,
                                     blocks=opt.resblocks, transfer_data=True)
    else:
        raise NotImplementedError('Encoder model name [%s] is not recognized' % arch)
    return init_net(network, init_type, init_gain, gpu_ids)

## (b)  MeshConv: The fundamental building block


In [20]:
from torch import nn

In [21]:
class MeshConv(nn.Module):
    """Convolution between each edge and its 4 incident (1-ring) edge neighbours.

    The forward pass takes
    x: edge features (Batch x Features x Edges)
    mesh: list of mesh data-structures (len(mesh) == Batch)
    and returns the convolved features.
    """
    def __init__(self, in_channels, out_channels, k=5, bias=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=(1, k), bias=bias)
        self.k = k

    def __call__(self, edge_f, mesh):
        return self.forward(edge_f, mesh)

    def forward(self, x, mesh):
        print(x.shape)
        x = x.squeeze(-1)
        # one padded neighbour table per mesh, stacked along the batch axis
        tables = [self.pad_gemm(m, x.shape[2], x.device) for m in mesh]
        stacked = torch.cat(tables, 0)
        # gather the 'neighbourhood image' and convolve it
        neighbourhood = self.create_GeMM(x, stacked)
        return self.conv(neighbourhood)

    def flatten_gemm_inds(self, Gi):
        """Offset per-sample indices so they address the batch-flattened feature rows."""
        batch, rows, width = Gi.shape
        rows += 1  # account for the zero row prepended to every sample
        positions = torch.arange(batch * rows, device=Gi.device).float()
        sample_of = torch.floor(positions / rows).view(batch, rows)
        offsets = (sample_of * rows).view(batch, rows, 1).repeat(1, 1, width)
        # drop the offset row that belongs to the padding row itself
        return Gi.float() + offsets[:, 1:, :]

    def create_GeMM(self, x, Gi):
        """Gather edge features (x) through the 1-ring indices (Gi).

        Applies symmetric functions so the result does not depend on neighbour
        order and returns a 'fake image' for 2d convolution with dimensions
        Batch x Channels x Edges x 5.
        """
        index_shape = Gi.shape
        # prepend a zero feature column to every sample; shifted index 0 points at it
        zeros = torch.zeros((x.shape[0], x.shape[1], 1), requires_grad=True, device=x.device)
        x = torch.cat((zeros, x), dim=2)
        Gi = Gi + 1

        flat_index = self.flatten_gemm_inds(Gi).view(-1).long()

        # lay the features out as (Batch * Edges) x Channels rows
        batch, channels, edges = x.shape
        rows = x.permute(0, 2, 1).contiguous().view(batch * edges, channels)

        gathered = torch.index_select(rows, dim=0, index=flat_index)
        gathered = gathered.view(index_shape[0], index_shape[1], index_shape[2], -1)
        f = gathered.permute(0, 3, 1, 2)

        # symmetric functions over the opposite neighbour pairs (1, 3) and (2, 4)
        centre = f[:, :, :, 0]
        a, b, c, d = f[:, :, :, 1], f[:, :, :, 2], f[:, :, :, 3], f[:, :, :, 4]
        sum_ac = a + c
        sum_bd = b + d
        diff_ac = torch.abs(a - c)
        diff_bd = torch.abs(b - d)
        return torch.stack([centre, sum_ac, sum_bd, diff_ac, diff_bd], dim=3)

    def pad_gemm(self, m, xsz, device):
        """Build the padded neighbour table of one mesh.

        Takes m.gemm_edges (#edges x 4), prepends each edge's own id to make
        #edges x 5, zero-pads the rows up to xsz and adds a leading batch axis.
        """
        neighbours = torch.tensor(m.gemm_edges, device=device).float()
        neighbours = neighbours.requires_grad_()
        own_ids = torch.arange(m.edges_count, device=device).float().unsqueeze(1)
        table = torch.cat((own_ids, neighbours), dim=1)
        missing_rows = xsz - m.edges_count
        table = F.pad(table, (0, 0, 0, missing_rows), "constant", 0)
        return table.unsqueeze(0)


##  The Classifier Model Class

This model contains 
> (i) Data to be processed <br />
> (ii) The classifier definition <br />
> (iii) Backprop algorithm <br />

In [22]:
class ClassifierModel:
    """ Wrapper that trains, evaluates, saves and loads the mesh network.
    :args opt: structure containing configuration params
    e.g.,
    --dataset_mode -> classification / segmentation)
    --arch -> network type
    """
    def __init__(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.is_train = opt.is_train
        # first requested GPU if there is one, CPU otherwise
        if self.gpu_ids:
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')
        self.save_dir = join(opt.checkpoints_dir, opt.name)

        # populated later by set_input / backward / the training branch below
        self.optimizer = None
        self.edge_features = None
        self.labels = None
        self.mesh = None
        self.soft_label = None
        self.loss = None

        self.nclasses = opt.nclasses

        # build the network and put it in train or eval mode
        self.net = define_classifier(opt.input_nc, opt.ncf, opt.ninput_edges, opt.nclasses, opt,
                                     self.gpu_ids, opt.arch, opt.init_type, opt.init_gain)
        self.net.train(self.is_train)

        self.criterion = networks.define_loss(opt).to(self.device)

        if self.is_train:
            self.optimizer = torch.optim.Adam(self.net.parameters(), lr=opt.lr,
                                              betas=(opt.beta1, 0.999))
            self.scheduler = networks.get_scheduler(self.optimizer, opt)
            print_network(self.net)

        # restore weights for testing, or when resuming training
        if not self.is_train or opt.continue_train:
            self.load_network(opt.which_epoch)

    def set_input(self, data):
        """Move one batch (features, labels, meshes) onto the model's device."""
        features = torch.from_numpy(data['edge_features']).float()
        targets = torch.from_numpy(data['label']).long()
        self.edge_features = features.to(self.device).requires_grad_(self.is_train)
        self.labels = targets.to(self.device)
        self.mesh = data['mesh']
        evaluating_segmentation = self.opt.dataset_mode == 'segmentation' and not self.is_train
        if evaluating_segmentation:
            self.soft_label = torch.from_numpy(data['soft_label'])

    def forward(self):
        """Run the network on the current batch."""
        return self.net(self.edge_features, self.mesh)

    def backward(self, out):
        """Compute the loss against the current labels and back-propagate it."""
        self.loss = self.criterion(out, self.labels)
        self.loss.backward()

    def optimize_parameters(self):
        """One optimisation step: zero grads, forward, backward, update."""
        self.optimizer.zero_grad()
        prediction = self.forward()
        self.backward(prediction)
        self.optimizer.step()

    def load_network(self, which_epoch):
        """load model from disk"""
        load_path = join(self.save_dir, '%s_net.pth' % which_epoch)
        # unwrap DataParallel so the state dict keys line up
        target_net = self.net.module if isinstance(self.net, torch.nn.DataParallel) else self.net
        print('loading the model from %s' % load_path)
        weights = torch.load(load_path, map_location=str(self.device))
        if hasattr(weights, '_metadata'):
            del weights._metadata
        target_net.load_state_dict(weights)

    def save_network(self, which_epoch):
        """save model to disk"""
        save_path = join(self.save_dir, '%s_net.pth' % (which_epoch))
        on_gpu = len(self.gpu_ids) > 0 and torch.cuda.is_available()
        if not on_gpu:
            torch.save(self.net.cpu().state_dict(), save_path)
            return
        # save from the CPU, then move the network back to its GPU
        torch.save(self.net.module.cpu().state_dict(), save_path)
        self.net.cuda(self.gpu_ids[0])

    def update_learning_rate(self):
        """update learning rate (called once every epoch)"""
        self.scheduler.step()
        current_lr = self.optimizer.param_groups[0]['lr']
        print('learning rate = %.7f' % current_lr)

    def test(self):
        """tests model
        returns: number correct and total number
        """
        with torch.no_grad():
            out = self.forward()
            # index of the highest score along dim 1 is the predicted class
            _, predicted = out.data.max(1)
            expected = self.labels
            self.export_segmentation(predicted.cpu())
            correct = self.get_accuracy(predicted, expected)
        return correct, len(expected)

    def get_accuracy(self, pred, labels):
        """computes accuracy for classification / segmentation """
        mode = self.opt.dataset_mode
        if mode == 'classification':
            correct = pred.eq(labels).sum()
        elif mode == 'segmentation':
            correct = seg_accuracy(pred, self.soft_label, self.mesh)
        return correct

    def export_segmentation(self, pred_seg):
        """In segmentation mode, hand each mesh its predicted segments for export."""
        if self.opt.dataset_mode != 'segmentation':
            return
        for index, mesh in enumerate(self.mesh):
            mesh.export_segments(pred_seg[index, :])


## (iii) Using Argument Parser with Default Parameters


In [23]:
class BaseOptions:
    """Command-line options shared by the training and testing entry points.

    ``parse`` reads ``self.is_train``, which this class does not set itself.
    """
    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        """Register the shared arguments on the parser."""
        add = self.parser.add_argument
        # data params
        add('--dataroot', default='datasets/shrec_16',
            help='path to meshes (should have subfolders train, test)')
        add('--dataset_mode', choices={"classification", "segmentation"}, default='classification')
        add('--ninput_edges', type=int, default=750,
            help='# of input edges (will include dummy edges)')
        add('--max_dataset_size', type=int, default=float("inf"),
            help='Maximum number of samples per epoch')
        # network params
        add('--batch_size', type=int, default=16, help='input batch size')
        add('--arch', type=str, default='mconvnet', help='selects network to use')  # todo add choices
        add('--resblocks', type=int, default=1, help='# of res blocks')
        add('--fc_n', type=int, default=100, help='# between fc and nclasses')  # todo make generic
        add('--ncf', nargs='+', default=[64, 128, 256, 256], type=int, help='conv filters')
        add('--pool_res', nargs='+', default=[600, 450, 300, 180], type=int, help='pooling res')
        add('--norm', type=str, default='group',
            help='instance normalization or batch normalization or group normalization')
        add('--num_groups', type=int, default=16, help='# of groups for groupnorm')
        add('--init_type', type=str, default='normal',
            help='network initialization [normal|xavier|kaiming|orthogonal]')
        add('--init_gain', type=float, default=0.02,
            help='scaling factor for normal, xavier and orthogonal.')
        # general params
        add('--num_threads', default=3, type=int, help='# threads for loading data')
        add('--gpu_ids', type=str, default='0',
            help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        add('--name', type=str, default='debug',
            help='name of the experiment. It decides where to store samples and models')
        add('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        add('--serial_batches', action='store_true',
            help='if true, takes meshes in order, otherwise takes them randomly')
        add('--seed', type=int, help='if specified, uses seed')
        # visualization params
        add('--export_folder', type=str, default='',
            help='exports intermediate collapses to this folder')
        self.initialized = True

    def parse(self):
        """Parse the known command-line arguments and post-process them.

        Converts ``gpu_ids`` to a list of non-negative ints, applies the seed,
        creates the export folder if one is requested and, when training,
        prints the options and writes them to ``opt.txt``.
        """
        if not self.initialized:
            self.initialize()
        # unrecognised arguments are ignored
        self.opt, _ignored = self.parser.parse_known_args()
        self.opt.is_train = self.is_train   # train or test

        # "0,1,-1" -> [0, 1]; negative ids select the CPU
        id_tokens = self.opt.gpu_ids.split(',')
        self.opt.gpu_ids = []
        self.opt.gpu_ids.extend(gpu for gpu in map(int, id_tokens) if gpu >= 0)
        if len(self.opt.gpu_ids) > 0:
            torch.cuda.set_device(self.opt.gpu_ids[0])

        args = vars(self.opt)

        if self.opt.seed is not None:
            import numpy as np
            import random
            torch.manual_seed(self.opt.seed)
            np.random.seed(self.opt.seed)
            random.seed(self.opt.seed)

        if self.opt.export_folder:
            self.opt.export_folder = os.path.join(self.opt.checkpoints_dir, self.opt.name,
                                                  self.opt.export_folder)
            util.mkdir(self.opt.export_folder)

        if not self.is_train:
            return self.opt

        header = '------------ Options -------------'
        footer = '-------------- End ----------------'
        option_lines = ['%s: %s' % (str(k), str(v)) for k, v in sorted(args.items())]

        print(header)
        for line in option_lines:
            print(line)
        print(footer)

        # save to the disk
        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        util.mkdir(expr_dir)
        with open(os.path.join(expr_dir, 'opt.txt'), 'wt') as opt_file:
            opt_file.write(header + '\n')
            for line in option_lines:
                opt_file.write(line + '\n')
            opt_file.write(footer + '\n')
        return self.opt


In [24]:
class TrainOptions(BaseOptions):
    """Training-time options layered on top of the ones BaseOptions registers."""

    def initialize(self):
        """Register the base options, then the training-only flags, and mark
        this option set as a training configuration (``is_train = True``)."""
        BaseOptions.initialize(self)
        # (flag, add_argument keyword arguments), registered in the listed order.
        training_flags = (
            ('--print_freq', {'type': int, 'default': 10, 'help': 'frequency of showing training results on console'}),
            ('--save_latest_freq', {'type': int, 'default': 250, 'help': 'frequency of saving the latest results'}),
            ('--save_epoch_freq', {'type': int, 'default': 1, 'help': 'frequency of saving checkpoints at the end of epochs'}),
            ('--run_test_freq', {'type': int, 'default': 1, 'help': 'frequency of running test in training script'}),
            ('--continue_train', {'action': 'store_true', 'help': 'continue training: load the latest model'}),
            ('--epoch_count', {'type': int, 'default': 1, 'help': 'the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...'}),
            ('--phase', {'type': str, 'default': 'train', 'help': 'train, val, test, etc'}),
            ('--which_epoch', {'type': str, 'default': 'latest', 'help': 'which epoch to load? set to latest to use latest cached model'}),
            ('--niter', {'type': int, 'default': 100, 'help': '# of iter at starting learning rate'}),
            ('--niter_decay', {'type': int, 'default': 100, 'help': '# of iter to linearly decay learning rate to zero'}),
            ('--beta1', {'type': float, 'default': 0.9, 'help': 'momentum term of adam'}),
            ('--lr', {'type': float, 'default': 0.0002, 'help': 'initial learning rate for adam'}),
            ('--lr_policy', {'type': str, 'default': 'lambda', 'help': 'learning rate policy: lambda|step|plateau'}),
            ('--lr_decay_iters', {'type': int, 'default': 50, 'help': 'multiply by a gamma every lr_decay_iters iterations'}),
            # data augmentation stuff
            ('--num_aug', {'type': int, 'default': 20, 'help': '# of augmentation files'}),
            ('--scale_verts', {'action': 'store_true', 'help': 'non-uniformly scale the mesh e.g., in x, y or z'}),
            ('--slide_verts', {'type': float, 'default': 0.2, 'help': 'percent vertices which will be shifted along the mesh surface'}),
            ('--flip_edges', {'type': float, 'default': 0.2, 'help': 'percent of edges to randomly flip'}),
            # tensorboard visualization
            ('--no_vis', {'action': 'store_true', 'help': 'will not use tensorboard'}),
            ('--verbose_plot', {'action': 'store_true', 'help': 'plots network weights, etc.'}),
        )
        for flag, flag_kwargs in training_flags:
            self.parser.add_argument(flag, **flag_kwargs)
        self.is_train = True

### (a) Calling Parameters from Argparse

In [25]:
# Build the training option set. Because TrainOptions sets is_train = True,
# parse() prints the resolved options (shown below) and writes them to opt.txt.
opt = TrainOptions().parse()

------------ Options -------------
arch: mconvnet
batch_size: 16
beta1: 0.9
checkpoints_dir: ./checkpoints
continue_train: False
dataroot: datasets/shrec_16
dataset_mode: classification
epoch_count: 1
export_folder: 
fc_n: 100
flip_edges: 0.2
gpu_ids: [0]
init_gain: 0.02
init_type: normal
is_train: True
lr: 0.0002
lr_decay_iters: 50
lr_policy: lambda
max_dataset_size: inf
name: debug
ncf: [64, 128, 256, 256]
ninput_edges: 750
niter: 100
niter_decay: 100
no_vis: False
norm: group
num_aug: 20
num_groups: 16
num_threads: 3
phase: train
pool_res: [600, 450, 300, 180]
print_freq: 10
resblocks: 1
run_test_freq: 1
save_epoch_freq: 1
save_latest_freq: 250
scale_verts: False
seed: None
serial_batches: False
slide_verts: 0.2
verbose_plot: False
which_epoch: latest
-------------- End ----------------


### (b) Initializing the dataset and network
 (i) model.net => the entire network architecture <br />

In [26]:
# Switch the parsed options to test mode before building the loader and model.
# NOTE(review): the log below shows a checkpoint being loaded, and the training
# cell further down fails with "'NoneType' object has no attribute 'zero_grad'".
# Presumably no optimizer is created when is_train is False — confirm in
# ClassifierModel before running the training cell with this setting.
opt.is_train = False
dataset = DataLoader(opt)
print('Dataset Size is: ', len(dataset))
model = ClassifierModel(opt)

loaded mean / std from cache
Dataset Size is:  480
Norm layer functools.partial(<class 'torch.nn.modules.normalization.GroupNorm'>, affine=True, num_groups=16)
Nf0 5
Number of Convolutional Filter [64, 128, 256, 256]
Number of Classes 30
Number of Input Edges 750
Edge Resolution Across Layers [600, 450, 300, 180]
Number of nodes is penultimate fc layer 100
Number of Residual Blocks:: Skip Connections 1
In Channels Start : [5, 64, 128, 256, 256]
loading the model from ./checkpoints/debug/latest_net.pth


## The Training Paradigm 

 - Writer is a Summary Writer used to record the important values during backpropagation.  <br />  <br />

 - model.set_input(data) => Sets the input data of the network <br />  <br />

 - model.optimize_parameters => A call to optimize the network parameters. 

In [27]:
# Training driver: iterate over epochs and batches, optimise the model, log
# losses, save checkpoints and periodically evaluate.
# NOTE(review): run_test is not imported in the visible part of this notebook;
# confirm it is in scope before running this cell.
writer = Writer(opt)
# plot_loss() needs the dataset size; derive it here so the cell does not
# depend on a `dataset_size` name left over from another cell.
dataset_size = len(dataset)
total_steps = 0
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    iter_data_time = time.time()
    epoch_iter = 0
    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        # Time spent waiting for this batch; only refreshed on iterations that
        # start on a print_freq boundary.
        if total_steps % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time
        # Step counters advance by samples, not by batches.
        total_steps += opt.batch_size
        epoch_iter += opt.batch_size
        model.set_input(data)
        model.optimize_parameters()

        if total_steps % opt.print_freq == 0:
            loss = model.loss
            # Per-sample compute time for this batch.
            t = (time.time() - iter_start_time) / opt.batch_size
            writer.print_current_losses(epoch, epoch_iter, loss, t, t_data)
            writer.plot_loss(loss, epoch, epoch_iter, dataset_size)

        # Checkpoint on batch index (includes batch 0 of every epoch).
        if i % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
            model.save_network('latest')

        iter_data_time = time.time()
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
        model.save_network('latest')
        model.save_network(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
    model.update_learning_rate()
    if opt.verbose_plot:
        writer.plot_model_wts(model, epoch)

    if epoch % opt.run_test_freq == 0:
        acc = run_test(epoch)
        writer.plot_acc(acc, epoch)
writer.close()

  return array(a, dtype, copy=False, order=order, subok=True)
  return array(a, dtype, copy=False, order=order, subok=True)
  return array(a, dtype, copy=False, order=order, subok=True)


AttributeError: 'NoneType' object has no attribute 'zero_grad'

In [28]:
model.net

DataParallel(
  (module): MeshConvNet(
    (conv0): MResConv(
      (conv0): MeshConv(
        (conv): Conv2d(5, 64, kernel_size=(1, 5), stride=(1, 1), bias=False)
      )
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): MeshConv(
        (conv): Conv2d(64, 64, kernel_size=(1, 5), stride=(1, 1), bias=False)
      )
    )
    (norm0): GroupNorm(16, 64, eps=1e-05, affine=True)
    (pool0): MeshPool()
    (conv1): MResConv(
      (conv0): MeshConv(
        (conv): Conv2d(64, 128, kernel_size=(1, 5), stride=(1, 1), bias=False)
      )
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): MeshConv(
        (conv): Conv2d(128, 128, kernel_size=(1, 5), stride=(1, 1), bias=False)
      )
    )
    (norm1): GroupNorm(16, 128, eps=1e-05, affine=True)
    (pool1): MeshPool()
    (conv2): MResConv(
      (conv0): MeshConv(
        (conv): Conv2d(128, 256, kernel_size=(1, 5), stride=(1,

## Breaking down the convolution 

### Layer 1: A MeshConv Layer: 

#### MResConv + Group Norm + MeshPool

- Input Edge features = (bs, 5, 750)

- Model mesh 

- Input Size  = (edges) => (16,5,750) + (mesh) 

- The Mesh is nothing but a (mesh.obj) file


### Layer 1: Output mesh has 64 dimensions

In [29]:
# Feed a batch to the model; `data` is presumably the last batch left over from
# the training-loop cell — confirm it is still defined.
model.set_input(data)

In [30]:
# Layer 1 convolution: conv0 is the first MResConv block in the printed model.
# The shapes printed below are [16, 5, 750] and [16, 64, 750, 1] — presumably
# the block's input and output; confirm where they are printed.
out1 = model.net.module.conv0(model.edge_features,model.mesh)

torch.Size([16, 5, 750])
torch.Size([16, 64, 750, 1])


In [31]:
# norm0 is GroupNorm(16, 64) in the printed model.
out1 = model.net.module.norm0(out1)

In [32]:
# pool0 is a MeshPool; the next cell shows the result has shape [16, 64, 600].
out1 = model.net.module.pool0(out1, model.mesh) 

In [27]:
out1.shape

torch.Size([16, 64, 600])

### Layer 2: A Mesh Conv in residual format. 

In [28]:
# Layer 2 convolution: conv1 is an MResConv whose first conv maps 64 -> 128
# channels in the printed model.
out2 = model.net.module.conv1(out1,model.mesh)

In [29]:
# norm1 is GroupNorm(16, 128) in the printed model.
out2 = model.net.module.norm1(out2)

In [30]:
# pool1 is the second MeshPool in the printed model.
out2 = model.net.module.pool1(out2, model.mesh)

### Layer3: A Similar MeshConv residual layer

In [31]:
# Layer 3 convolution: conv2 is an MResConv whose first conv takes 128 input
# channels (the printed model is truncated after that point).
out3 = model.net.module.conv2(out2,model.mesh)

In [32]:
# Apply the layer-3 norm module (not visible in the truncated model printout).
out3 = model.net.module.norm2(out3)

In [33]:
# Apply the layer-3 pooling module to the normalised features and the meshes.
out3 = model.net.module.pool2(out3, model.mesh)

### Layer 4: The last MeshConv Layer before final MLP classification

In [34]:
# Layer 4 convolution, applied to the layer-3 output and the same meshes.
out4 = model.net.module.conv3(out3,model.mesh)

In [35]:
# Apply the layer-4 norm module.
out4 = model.net.module.norm3(out4)

In [36]:
# Apply the layer-4 pooling module; its output feeds the classifier head below.
out4 = model.net.module.pool3(out4, model.mesh)

### Layer 5: Average Pooling and Fully Connected Layers

In [37]:
# `gp` is presumably the average-pooling layer named in the heading — confirm.
out5 = model.net.module.gp(out4)
# Flatten to 2-D with k[-1] columns so the result can feed the linear layers.
out5 = out5.view(-1, model.net.module.k[-1])

### Layer 6: Obtaining the final classified layer

In [38]:
# Classifier head: fc1 followed by ReLU, then fc2 with no activation applied here.
out6 = F.relu(model.net.module.fc1(out5))
out6 = model.net.module.fc2(out6)

### MeshResConv: => Next Layer of Convolution

In [39]:
# Re-feed the batch before stepping through the residual block by hand.
model.set_input(data)

In [40]:
# First MeshConv inside the conv0 residual block (wraps Conv2d(5, 64) in the
# printed model); its output is kept as the skip-connection branch.
inter1 = model.net.module.conv0.conv0(model.edge_features, model.mesh)

In [41]:
# x is an alias of inter1, not a copy; the loop in the next cell rebinds x
# before the in-place `x += inter1`.
x = inter1

In [42]:
# Manual replay of the residual branch of conv0: for each skip, apply
# ReLU -> bn<k> -> conv<k>, then add the saved first-conv output back in.
res_block = model.net.module.conv0
for i in range(res_block.skips):
    layer_no = i + 1
    normed = getattr(res_block, 'bn{}'.format(layer_no))(F.relu(x))
    x = getattr(res_block, 'conv{}'.format(layer_no))(normed, model.mesh)
# In-place add of the skip connection.
x += inter1

In [43]:
# Draw the autograd graph that produced inter1 and render it to meshconv.png
# (the returned path is shown below).
make_dot(inter1).render("meshconv", format="png")

'meshconv.png'

### Smallest Layer: The Mesh Convolutional Layers

In [44]:
# Reset the model input and take the raw edge features as the starting point
# for replaying MeshConv step by step.
model.set_input(data)
x = model.edge_features

In [65]:
# Drop the trailing dimension if it has size 1 (no-op otherwise).
x = x.squeeze(-1)

In [47]:
sample_mesh = model.mesh[0] #A sample mesh for forward pass. 

In [48]:
xsz = x.shape[2] #Number of edges 

In [63]:
sample_mesh.gemm_edges.shape

(180, 4)

### (i) Entering the pad_gemm function inside the architecture. 

In [64]:
# Turn the mesh's gemm_edges array (shape (180, 4) above) into a float tensor
# on the same device as x. Presumably each row holds an edge's 4 neighbouring
# edge ids — confirm against the mesh class.
padded_gemm = torch.tensor(sample_mesh.gemm_edges, device= x.device).float()

#### (a) Adding a numbered list as the edges list

In [65]:
# Mark the neighbour table as requiring grad (in place; returns the same tensor).
padded_gemm = padded_gemm.requires_grad_()
# Column vector 0..edges_count-1: each edge's own index, as floats.
edge_ids = torch.arange(sample_mesh.edges_count, device=x.device).float().unsqueeze(1)
# Put the edge's own id in front of its neighbour columns.
padded_gemm = torch.cat((edge_ids, padded_gemm), dim=1)

#### (b) Padding to ensure size is same as edge count

In [66]:
# F.pad takes (left, right) pairs starting from the last dimension: (0, 0)
# leaves the columns alone and (0, n) appends n zero rows, so the table grows
# from edges_count rows to xsz rows.
padded_gemm = F.pad(padded_gemm, (0, 0, 0, xsz - sample_mesh.edges_count), "constant", 0) #Adding zeros to complete the padding
# Add a leading batch dimension (the shape shown below is [1, 750, 5]).
padded_gemm = padded_gemm.unsqueeze(0)

In [67]:
padded_gemm.shape

torch.Size([1, 750, 5])

#### (c) Assigning padded_gemm to a variable Gi

In [68]:
Gi = padded_gemm

### (ii) Proceeding to the create_GeMM function

In [69]:
# Remember the index tensor's shape; it is used later to reshape the gathered features.
Gishape = Gi.shape

# One all-zero feature slot per sample and channel, shape
# (x.shape[0], x.shape[1], 1); it is concatenated in front of x along dim 2
# in the next cell.
padding = torch.zeros((x.shape[0], x.shape[1], 1), requires_grad=True, device=x.device)

#### (a) Flattening the GeMM indices

In [70]:
# Prepend the zero slot to the features along dim 2, so position 0 of every
# sample holds zeros and the real features start at position 1.
y = torch.cat((padding, x), dim=2)
# Shift all indices by one to match that prepended slot.
# NOTE(review): presumably this makes a "no neighbour" index of -1 land on the
# zero slot — confirm how gemm_edges encodes missing neighbours.
Gi = Gi + 1

In [184]:
#Gi_flat = model.net.module.conv0.conv0.flatten_gemm_inds(Gi)

In [97]:
# Inline version of the index flattening (see the commented-out
# flatten_gemm_inds call above): convert per-sample indices into indices into
# the batch-flattened feature rows.
# The last dimension is unpacked into n_cols rather than `nn`, because this
# cell runs at module level and a variable called `nn` would overwrite the
# `from torch import nn` import for every later cell.
(b, ne, n_cols) = Gi.shape
# One extra slot per sample for the zero slot prepended to the features.
ne += 1
# batch_n[i, j] == i: the sample index repeated for each of the ne slots.
batch_n = torch.floor(torch.arange(b * ne, device=Gi.device).float() / ne).view(b, ne)
# Offset of sample i's first row in the flattened (b * ne) layout.
add_fac = batch_n * ne
add_fac = add_fac.view(b, ne, 1)
add_fac = add_fac.repeat(1, 1, n_cols)
# flatten Gi: drop one slot from add_fac so it matches Gi's row count again.
Gi_flat = Gi.float() + add_fac[:, 1:, :]

In [102]:
add_fac.shape

torch.Size([1, 751, 5])

In [98]:
# first flatten indices
Gi_flat = Gi_flat.view(-1).long()

In [156]:
Gi.shape

torch.Size([1, 750, 5])

#### (b) Applying permutations: To switch the dimensions

In [157]:
# Lay the features out as one row per (sample, slot) so that the flattened
# indices in Gi_flat can select rows directly.
odim = y.shape
batch_sz, n_channels, n_slots = odim
# Move channels last, make the memory contiguous, then merge sample and slot axes.
y = y.permute(0, 2, 1).contiguous().view(batch_sz * n_slots, n_channels)

In [158]:
# Gather one feature row of y for every entry of the flattened index tensor.
f = torch.index_select(y, dim=0, index=Gi_flat)

In [162]:
f.shape

torch.Size([1, 5, 750, 5])

In [159]:
# Restore the index tensor's (batch, edge, column) layout; the trailing -1
# dimension takes whatever is left, i.e. the feature channels.
f = f.view(Gishape[0], Gishape[1], Gishape[2], -1)

In [161]:
# Move the channel axis to position 1: (batch, channels, edge, column).
f = f.permute(0, 3, 1, 2)

### Apply the symmetric functions for an equivariant conv

In [163]:
# Split the last axis into the edge's own features (column 0) and its four
# neighbour columns.
edge_self = f[..., 0]
nbr_a, nbr_b, nbr_c, nbr_d = f[..., 1], f[..., 2], f[..., 3], f[..., 4]
# Sums and absolute differences are unchanged when columns 1<->3 and 2<->4 are
# swapped, so the result does not depend on that ordering of the neighbours.
x_1 = nbr_a + nbr_c
x_2 = nbr_b + nbr_d
x_3 = (nbr_a - nbr_c).abs()
x_4 = (nbr_b - nbr_d).abs()
# Re-assemble the five columns along the last axis.
f = torch.stack((edge_self, x_1, x_2, x_3, x_4), dim=3)

In [164]:
# This is followed by a simple convolution operation: to obtain the result

In [167]:
f.shape

torch.Size([1, 5, 750, 5])

tensor([[[  1., 104., 141., 154.,  55.],
         [  2.,   3.,  84., 112., 148.],
         [  3.,  84.,   2., 107., 158.],
         ...,
         [  1.,   1.,   1.,   1.,   1.],
         [  1.,   1.,   1.,   1.,   1.],
         [  1.,   1.,   1.,   1.,   1.]]], device='cuda:0',
       grad_fn=<AddBackward0>)