In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import logging
import warnings
warnings.filterwarnings("ignore")
logging.getLogger('tensorflow').setLevel(logging.ERROR)
os.environ["KMP_AFFINITY"] = "noverbose"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(3)

from tensorflow import keras
from tqdm.autonotebook import tqdm, trange
import scipy as sp

from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dropout, Input
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed

from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GATConv,GCNConv
from spektral.transforms import LayerPreprocess
from spektral.data.graph import Graph
from spektral.data.dataset import Dataset
from spektral.data import BatchLoader
from spektral.utils import gcn_filter


import scipy.sparse as spp
from spektral.transforms.normalize_adj import NormalizeAdj
from spektral.data import DisjointLoader
from tensorflow.keras.regularizers import l2

from scipy.spatial.transform import Rotation as R

import polyscope as ps

## Graph-Based Machine Learning Model
Now we will train a graph neural network model to perform mesh segmentation.

### Redefine Augmentation Pipeline

This is a copy of the augmentation function from the previous notebook; here it is used for active (on-the-fly) augmentation during training.

In [None]:
def Augmentation_pipeline(v_mat,ffd_scale = 0.35, scale=0.15, rotation=np.pi/4,translation=0.0):
    """Randomly deform, rotate, rescale and re-centre a set of vertices.

    Used as the active (on-the-fly) augmentation during training. All random
    draws come from NumPy's global RNG in a fixed order, so results are
    reproducible under ``np.random.seed``.

    Args:
        v_mat: vertex coordinates, used as ``(n_vertices, 3)`` (the rotation
            is applied as ``R @ v_mat.T``).
        ffd_scale: half-width of the uniform range for every free-form
            deformation (FFD) control-point displacement.
        scale: half-width of the uniform range ``u``; vertices are scaled by
            ``1 - u``.
        rotation: half-width of the uniform range for each component of the
            random rotation vector (radians).
        translation: half-width of the uniform range for the final per-axis
            offset (0 disables it).

    Returns:
        The augmented vertex array. No in-place operation is applied to the
        caller's array.
    """
    # Imported here: pygem is not imported anywhere else in this notebook, so
    # on a fresh kernel the bare name would raise NameError.
    import pygem

    # Free-form deformation with a 4x4x4 control lattice.
    # NOTE(review): uses PyGeM's default lattice box; confirm the vertices are
    # normalised to lie inside it.
    ffd = pygem.FFD([4, 4, 4])
    ffd.array_mu_x = np.random.uniform(low=-ffd_scale, high=ffd_scale, size=ffd.array_mu_x.shape)
    ffd.array_mu_y = np.random.uniform(low=-ffd_scale, high=ffd_scale, size=ffd.array_mu_y.shape)
    ffd.array_mu_z = np.random.uniform(low=-ffd_scale, high=ffd_scale, size=ffd.array_mu_z.shape)
    v_mat = ffd(v_mat)

    # Random rotation applied to every vertex (row).
    rot = R.from_rotvec(np.random.uniform(low=-rotation, high=rotation, size=3)).as_matrix()
    v_mat = (rot @ v_mat.T).T

    # Random isotropic scaling.
    v_mat = v_mat * (1 - np.random.uniform(low=-scale, high=scale))

    # Centre the bounding box on 0.5 along every axis (min + max == 1).
    v_mat = v_mat - v_mat.min(0)
    v_mat = v_mat + (1 - v_mat.max(0)) / 2

    # Optional random translation.
    v_mat = v_mat + np.random.uniform(low=-translation, high=translation, size=3)

    return v_mat

### Defining A Graph Dataset

The dataset class below loads the meshes from the NumPy file and splits them into a training set (the first 80% of rows) and a validation set (the remaining rows).

In [None]:
class Aircraft_Dataset(Dataset):
    """Spektral ``Dataset`` of meshes loaded from a pickled NumPy object array.

    Each row of the array is one mesh. ``read`` uses column 0 as the node
    feature matrix and column -2 as an edge list whose first two columns are
    source/target node indices. The training code elsewhere in this notebook
    also reads column 1 (faces) and column -1 (per-face labels) via ``raw``.

    Rows ``[0, split)`` form the training subset and rows ``[split, n)`` the
    validation subset, where ``split = int(n * train_fraction)``.

    Args:
        path: ``.npy`` file to load. ``allow_pickle=True`` is needed for the
            object array but can execute arbitrary code -- only load trusted
            files.
        subset: ``"train"`` selects the training rows; any other value selects
            the validation rows.
        train_fraction: fraction of rows assigned to the training subset.
        **kwargs: forwarded to the Spektral ``Dataset`` (e.g. ``transforms``).
    """

    def __init__(self, path = 'dataset.npy', subset="train", train_fraction=0.8, **kwargs):

        self.raw = np.load(path,allow_pickle=True)
        n_total = self.raw.shape[0]
        split = int(n_total * train_fraction)
        # ``st`` / ``n_samples`` are the start / stop row indices used by read();
        # note that for the validation subset ``n_samples`` is a stop index,
        # not a count.
        if subset == "train":
            self.st, self.n_samples = 0, split
        else:
            self.st, self.n_samples = split, n_total

        super().__init__(**kwargs)

    def read(self):
        """Build one ``Graph`` per row in ``[self.st, self.n_samples)``.

        ``Graph.y`` stores the absolute row index so that batches can later be
        mapped back to faces and labels in ``self.raw``.
        """
        graphs = []
        for row in range(self.st, self.n_samples):
            x = self.raw[row, 0]
            edges = self.raw[row, -2]
            n_nodes = x.shape[0]
            # Unweighted adjacency; scipy sums duplicate (src, dst) entries.
            a = spp.csr_matrix(
                (np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                shape=[n_nodes, n_nodes])
            graphs.append(Graph(x=x, a=a, y=row))
        # Spektral expects a list of Graph objects.
        return graphs

In [None]:
# Build the training and validation splits; each gets its own
# LayerPreprocess(GATConv) transform (Spektral applies GATConv's preprocessing
# to every graph).
gat_data, gat_data_val = (
    Aircraft_Dataset(subset=split_name, transforms=[LayerPreprocess(GATConv)])
    for split_name in ("train", "validation")
)

### Visualize Active Augmentation
Below is a quick visualization of the active augmentation applied to one training mesh.

In [None]:
%matplotlib widget
mesh = gat_data[5].x
new_mesh = Augmentation_pipeline(mesh)

ax = plt.figure(figsize=(8,8)).add_subplot(111, projection='3d')
ax.scatter(*new_mesh.T,s=1.5)
ax.scatter(*mesh.T,s=0.5)
plt.show()

### Defining The Model

Below are the model architectures used for mesh segmentation.

In [None]:
# Node-level branch: 4 stacked graph-attention (GATConv) layers followed by a
# dense head that maps every node to a 64-dim embedding.
class node_level(tf.keras.Model):
    """Per-node feature extractor: GATConv stack + Dense/BatchNorm head.

    An earlier 6-layer variant had two more 128-channel GAT layers
    (``gc_5`` / ``gc_6``); they were disabled and are not built here.

    Args:
        dropout: attention dropout rate for every GATConv layer.
        l2_reg: L2 penalty on GATConv kernel, attention-kernel and bias weights.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, dropout=0.1, l2_reg = 2.5e-4 , **kwargs):
        super(node_level, self).__init__(**kwargs)

        def gat(channels):
            # Every GAT layer shares this configuration; only the width varies.
            return GATConv(
                channels,
                attn_heads=8,
                concat_heads=True,
                dropout_rate=dropout,
                activation="elu",
                kernel_regularizer=l2(l2_reg),
                attn_kernel_regularizer=l2(l2_reg),
                bias_regularizer=l2(l2_reg))

        # Attribute names and creation order are unchanged so existing
        # TF-format checkpoints still restore into this model.
        self.gc_1 = gat(64)
        self.gc_2 = gat(64)
        self.gc_3 = gat(128)
        self.gc_4 = gat(128)

        self.prd_dense_1 = tf.keras.layers.Dense(256)
        self.prd_cbn1 = tf.keras.layers.BatchNormalization()
        # Weight-free activation shared by all hidden dense layers.
        self.LRelu = tf.keras.layers.LeakyReLU(alpha=0.3)

        self.prd_dense_2 = tf.keras.layers.Dense(256)
        self.prd_cbn2 = tf.keras.layers.BatchNormalization()

        self.prd_dense_3 = tf.keras.layers.Dense(128)
        self.prd_cbn3 = tf.keras.layers.BatchNormalization()

        self.prd_dense_4 = tf.keras.layers.Dense(64)

    def call(self, inputs, training = True):
        """Return per-node embeddings (last dimension 64).

        Args:
            inputs: ``[x, a, i]`` -- node features, adjacency and graph index
                (the graph index is not used by this branch).
            training: forwarded to the GAT layers (attention dropout) and the
                BatchNormalization layers.
        """
        x, a, _ = inputs

        for gat_layer in (self.gc_1, self.gc_2, self.gc_3, self.gc_4):
            x = gat_layer([x, a], training=training)

        hidden = ((self.prd_dense_1, self.prd_cbn1),
                  (self.prd_dense_2, self.prd_cbn2),
                  (self.prd_dense_3, self.prd_cbn3))
        for dense, norm in hidden:
            x = self.LRelu(norm(dense(x), training=training))

        return self.prd_dense_4(x)

In [None]:
# Transformation net (T-Net) from the PointNet paper: predicts an
# input_dim x input_dim matrix used to align the input point coordinates.
class TransformationNet(tf.keras.Model):
    """PointNet-style T-Net producing a single alignment matrix per call.

    Args:
        input_dim: coordinates per point; the output is ``input_dim x input_dim``.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, input_dim = 3, **kwargs):
        super(TransformationNet, self).__init__(**kwargs)
        self.output_dim = input_dim

        # Shared per-point MLP (kernel-size-1 convolutions) with LayerNorm.
        self.conv0 = tf.keras.layers.Conv1D(64,1)
        self.ln0 = tf.keras.layers.LayerNormalization()
        self.conv1 = tf.keras.layers.Conv1D(128,1)
        self.ln1 = tf.keras.layers.LayerNormalization()
        self.conv2 = tf.keras.layers.Conv1D(1024,1)
        self.ln2 = tf.keras.layers.LayerNormalization()

        # MLP on the pooled global feature.
        self.fc_0  = tf.keras.layers.Dense(512)
        self.ln3 = tf.keras.layers.LayerNormalization()
        self.fc_1  = tf.keras.layers.Dense(256)
        self.ln4 = tf.keras.layers.LayerNormalization()

        self.fc_2  = tf.keras.layers.Dense(input_dim*input_dim)

    def call(self, x):
        """Return an ``(input_dim, input_dim)`` matrix for the point set ``x``.

        ``x`` is used as ``(n_points, input_dim)`` and treated as one cloud
        (a batch of one), so several graphs concatenated in disjoint mode
        would all share a single transform.
        """
        relu = tf.keras.activations.relu
        x = tf.expand_dims(x, 0)  # -> (1, n_points, input_dim)
        x = relu(self.ln0(self.conv0(x)))
        x = relu(self.ln1(self.conv1(x)))
        x = relu(self.ln2(self.conv2(x)))

        # Max over points gives an order-invariant global feature.
        x = tf.reduce_max(x, axis=1)
        x = tf.reshape(x, [1, 1024])

        x = relu(self.ln3(self.fc_0(x)))
        x = relu(self.ln4(self.fc_1(x)))
        x = self.fc_2(x)

        # Adding the identity means a zero prediction leaves the points unchanged.
        return tf.reshape(x, [self.output_dim, self.output_dim]) + tf.eye(self.output_dim)

In [None]:
# Global branch: 3 GATConv layers interleaved with SAGPool pooling, then global
# attention pooling and a dense head -> one 64-dim vector per graph.
class global_level(tf.keras.Model):
    """Graph-level feature extractor used by ``surface_classifier``.

    Args:
        dropout: attention dropout rate for every GATConv layer.
        l2_reg: L2 penalty on GATConv kernel, attention-kernel and bias weights.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, dropout=0.1, l2_reg = 2.5e-4 , **kwargs):
        super(global_level, self).__init__(**kwargs)
        # The notebook's imports never bind the bare name ``spektral`` (only
        # ``from spektral... import ...``), so ``spektral.layers.X`` raised
        # NameError. Import the pooling layers explicitly.
        from spektral.layers import GlobalAttentionPool, SAGPool

        def gat(channels):
            # Every GAT layer shares this configuration; only the width varies.
            return GATConv(
                channels,
                attn_heads=8,
                concat_heads=True,
                dropout_rate=dropout,
                activation="elu",
                kernel_regularizer=l2(l2_reg),
                attn_kernel_regularizer=l2(l2_reg),
                bias_regularizer=l2(l2_reg))

        # Attribute names and creation order are unchanged (checkpoint
        # compatibility).
        self.gc_1 = gat(64)
        self.pool_1 = SAGPool(0.5)

        self.gc_2 = gat(128)
        self.pool_2 = SAGPool(0.25)

        self.gc_3 = gat(256)

        self.g_pool = GlobalAttentionPool(256)

        self.prd_dense_1 = tf.keras.layers.Dense(256)
        self.prd_cbn1 = tf.keras.layers.BatchNormalization()
        self.LRelu = tf.keras.layers.LeakyReLU(alpha=0.3)

        self.prd_dense_2 = tf.keras.layers.Dense(256)
        self.prd_cbn2 = tf.keras.layers.BatchNormalization()

        self.prd_dense_3 = tf.keras.layers.Dense(128)
        self.prd_cbn3 = tf.keras.layers.BatchNormalization()

        self.prd_dense_4 = tf.keras.layers.Dense(64)

    def call(self, inputs, training = True):
        """Return one 64-dim feature vector per graph in the disjoint batch.

        Args:
            inputs: ``[x, a, i]`` -- node features, adjacency and graph index.
            training: forwarded to the GAT, pooling and BatchNorm layers.
        """
        x, a, i = inputs

        x = self.gc_1([x, a], training=training)
        x, a, i = self.pool_1([x, a, i], training=training)

        x = self.gc_2([x, a], training=training)
        x, a, i = self.pool_2([x, a, i], training=training)

        x = self.gc_3([x, a], training=training)
        # Collapse each graph's remaining nodes into a single vector.
        x = self.g_pool([x, i], training=training)

        hidden = ((self.prd_dense_1, self.prd_cbn1),
                  (self.prd_dense_2, self.prd_cbn2),
                  (self.prd_dense_3, self.prd_cbn3))
        for dense, norm in hidden:
            x = self.LRelu(norm(dense(x), training=training))

        return self.prd_dense_4(x)

In [None]:
# Original model: node-level (local) GAT branch plus an optional graph-level
# (global) branch; outputs a softmax over 4 classes for every mesh face.
class surface_classifier(tf.keras.Model):
    """Face classifier combining per-node and (optionally) per-graph features.

    Args:
        dropout: dropout rate forwarded to the GAT branches.
        l2_reg: L2 penalty forwarded to the GAT branches.
        active_augmentation: callable applied to the batched vertex array in
            ``get_batch`` when ``training=True``; ``None`` disables it.
        budget: maximum number of faces per batch; larger batches are randomly
            subsampled without replacement.
        has_global: whether to build and use the ``global_level`` branch.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, dropout=0.1, l2_reg = 2.5e-4, active_augmentation = Augmentation_pipeline, budget = 50000, has_global = True, **kwargs):
        super(surface_classifier, self).__init__(**kwargs)

        self.has_global = has_global
        if has_global:
            self.g_m = global_level(dropout,l2_reg)
        self.n_m = node_level(dropout,l2_reg)

        # Per-node input embedding MLP.
        self.inp_emb_1 = tf.keras.layers.Dense(64)
        self.inp_bn_1 = tf.keras.layers.BatchNormalization()

        self.inp_emb_2 = tf.keras.layers.Dense(128)
        self.inp_bn_2 = tf.keras.layers.BatchNormalization()

        self.inp_emb_3 = tf.keras.layers.Dense(128)
        self.inp_bn_3 = tf.keras.layers.BatchNormalization()

        # Per-face prediction head.
        self.prd_dense_1 = tf.keras.layers.Dense(256)
        self.prd_cbn1 = tf.keras.layers.BatchNormalization()
        self.LRelu = tf.keras.layers.LeakyReLU(alpha=0.3)

        self.prd_dense_2 = tf.keras.layers.Dense(128)
        self.prd_cbn2 = tf.keras.layers.BatchNormalization()

        self.prd_dense_3 = tf.keras.layers.Dense(64)
        self.prd_cbn3 = tf.keras.layers.BatchNormalization()

        self.prd_dense_4 = tf.keras.layers.Dense(4)

        self.budget = budget

        self.aug_f = active_augmentation

    def call(self, inputs, training = True):
        """Return per-face class probabilities (softmax over 4 classes).

        Args:
            inputs: ``[x, a, i, faces, labels, inds]`` as built by
                ``get_batch``; ``labels`` is not used here.
            training: forwarded to BatchNormalization and the GAT branches.
                The default is ``True``; pass ``False`` for inference.
        """
        x, a, i, faces, _, inds = inputs

        for dense, norm in ((self.inp_emb_1, self.inp_bn_1),
                            (self.inp_emb_2, self.inp_bn_2),
                            (self.inp_emb_3, self.inp_bn_3)):
            x = self.LRelu(norm(dense(x), training=training))

        batch = [x, a, i]
        if self.has_global:
            g_features = self.g_m(batch, training=training)
        n_features = self.n_m(batch, training=training)

        # A face's feature is the mean of its three vertices' node features.
        gather_n_f = tf.gather_nd(n_features, tf.reshape(faces, [-1, 1]))
        x = tf.reduce_mean(tf.reshape(gather_n_f, [tf.shape(faces)[0], 3, -1]), 1)

        if self.has_global:
            # Append the global feature of the graph each face belongs to.
            gather_g_f = tf.gather_nd(g_features, tf.expand_dims(inds, -1))
            x = tf.concat([x, gather_g_f], -1)

        for dense, norm in ((self.prd_dense_1, self.prd_cbn1),
                            (self.prd_dense_2, self.prd_cbn2),
                            (self.prd_dense_3, self.prd_cbn3)):
            x = self.LRelu(norm(dense(x), training=training))

        return tf.math.softmax(self.prd_dense_4(x))

    def get_batch(self,dataset,batch, training = True):
        """Assemble model inputs for one ``DisjointLoader`` batch.

        Faces of every mesh are offset into the disjoint (concatenated) node
        numbering, and ``inds`` records which graph of the batch each face
        belongs to. If more than ``self.budget`` faces result, a random
        subset of that size is kept.

        Args:
            dataset: dataset the batch was drawn from; its ``raw`` rows supply
                node arrays (column 0), faces (column 1) and labels (column -1).
            batch: ``(inputs, y)`` from ``DisjointLoader``, where ``y`` holds
                the raw row index of each graph (``Graph.y``).
            training: apply ``self.aug_f`` (if not None) to the vertices.

        Returns:
            ``[x, a, i, faces, labels (float32), inds (int32)]``.
        """
        face_parts, label_parts, ind_parts = [], [], []
        base = 0
        for graph_idx, mesh in enumerate(dataset.raw[batch[1]]):
            mesh_faces = mesh[0, 1]
            face_parts.append(mesh_faces + base)
            label_parts.append(mesh[0, -1])
            ind_parts.append(np.full(mesh_faces.shape[0], graph_idx))
            # The next mesh's vertex indices start after this mesh's nodes.
            base += mesh[0, 0].shape[0]
        # Concatenate once instead of growing arrays inside the loop.
        faces = np.concatenate(face_parts, 0)
        labels = np.concatenate(label_parts, 0)
        inds = np.concatenate(ind_parts, 0)

        if faces.shape[0] > self.budget:
            keep = np.random.choice(faces.shape[0], size=self.budget, replace=False)
            faces, labels, inds = faces[keep], labels[keep], inds[keep]

        x = batch[0][0]
        # Use the configured augmentation (previously hard-wired, ignoring aug_f).
        if training and self.aug_f is not None:
            x = self.aug_f(x)
        return [x, batch[0][1], batch[0][2], faces, labels.astype(np.float32), inds.astype(np.int32)]

    def evaluate(self, batch):
        """Return ``(loss, accuracy)`` for one prepared batch in inference mode.

        ``call`` defaults to ``training=True``, which would enable attention
        dropout and make BatchNormalization use and update batch statistics
        with evaluation data, so ``training=False`` is passed explicitly.
        """
        cce = tf.keras.losses.CategoricalCrossentropy()
        m = tf.keras.metrics.CategoricalAccuracy()

        y_pred = self.call(batch, training=False)
        y = batch[4]

        loss = cce(y, y_pred)
        m.update_state(y, y_pred)

        return loss, m.result().numpy()

    def _validate(self, loader_val, val_data, n_steps):
        """Face-weighted mean ``(loss, accuracy)`` over ``n_steps`` batches of ``loader_val``."""
        l_ov = 0.
        acc_ov = 0.
        n_t = 0
        for _ in range(n_steps):
            batch = self.get_batch(val_data, next(loader_val), training=False)
            loss, acc = self.evaluate(batch)
            nf = batch[3].shape[0]
            n_t += nf
            l_ov += loss * nf
            acc_ov += acc * nf
        return l_ov / n_t, acc_ov / n_t

    def get_training_step(self):
        """Return ``train_step(batch, optimizer, m) -> (loss, batch_accuracy)``.

        The returned loss is the categorical cross-entropy; the gradient step
        also includes the layers' regularization losses. ``m`` is reset on
        every call, so the accuracy covers the current batch only.
        """

        def train_step(batch,optimizer,m):
            cce = tf.keras.losses.CategoricalCrossentropy()
            m.reset_state()
            with tf.GradientTape() as tape:
                y_pred = self.call(batch, training=True)
                y = batch[4]
                loss = cce(y, y_pred)
                # Custom loops must add regularization terms themselves;
                # otherwise the GATConv L2 regularizers have no effect.
                total_loss = loss + sum(self.losses)

            gradients = tape.gradient(total_loss, self.trainable_variables)
            optimizer.apply_gradients(zip(gradients, self.trainable_variables))

            m.update_state(y, y_pred)
            return loss, m.result()

        return train_step

    def train(self, dataset, val_data, epochs = 10, lr = 1e-4, batch_size = 2):
        """Train with Adam, printing face-weighted train/validation metrics per epoch.

        Args:
            dataset: training dataset (must expose ``raw`` for ``get_batch``).
            val_data: validation dataset (must expose ``raw``).
            epochs: number of passes over ``dataset``.
            lr: Adam learning rate.
            batch_size: graphs per training batch; validation uses 3x.
        """
        loader = DisjointLoader(dataset, batch_size=batch_size)
        loader_val = DisjointLoader(val_data, batch_size=batch_size*3)
        n_st = loader.steps_per_epoch
        n_st_val = loader_val.steps_per_epoch

        optimizer = tf.keras.optimizers.Adam(lr)

        train_step = self.get_training_step()
        m = tf.keras.metrics.CategoricalAccuracy()
        for epoch in range(epochs):
            prog = trange(n_st)
            l_ov = 0.
            acc_ov = 0.
            n_t = 0
            for _ in prog:
                batch = self.get_batch(dataset, next(loader))

                loss, acc = train_step(batch, optimizer, m)

                # Weight each batch by its face count -> per-face running means.
                nf = batch[3].shape[0]
                n_t += nf
                l_ov += loss * nf
                acc_ov += acc * nf

                prog.set_postfix_str('Loss: %f, Overall Loss: %f, Accuracy: %f, Overall Accuracy: %f' % (loss, l_ov/n_t, acc, acc_ov/n_t))

            print('Epoch %i Loss: %f, Accuracy: %f' % (epoch+1, l_ov/n_t, acc_ov/n_t))

            # Validation batches are looked up in val_data, the dataset the
            # validation loader actually iterates.
            val_loss, val_acc = self._validate(loader_val, val_data, n_st_val)
            print('Epoch %i Validation Loss: %f, Validation Accuracy: %f' % (epoch+1, val_loss, val_acc))

In [None]:
# Current model: a T-Net input alignment followed by the node-level (local)
# GAT branch; outputs a softmax over 4 classes for every mesh face.
class surface_classifier_with_tnet(tf.keras.Model):
    """Face classifier: T-Net alignment -> input MLP -> node_level -> face head.

    Args:
        dropout: dropout rate forwarded to the node-level GAT branch.
        l2_reg: L2 penalty forwarded to the node-level GAT branch.
        active_augmentation: callable applied to the batched vertex array in
            ``get_batch`` when ``training=True``; ``None`` disables it.
        budget: maximum number of faces per batch; larger batches are randomly
            subsampled without replacement.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, dropout=0.1, l2_reg = 2.5e-4, active_augmentation = Augmentation_pipeline, budget = 50000, **kwargs):
        super(surface_classifier_with_tnet, self).__init__(**kwargs)

        self.t_net = TransformationNet()
        self.n_m = node_level(dropout,l2_reg)

        # Per-node input embedding MLP.
        self.inp_emb_1 = tf.keras.layers.Dense(128)
        self.inp_bn_1 = tf.keras.layers.BatchNormalization()

        self.inp_emb_2 = tf.keras.layers.Dense(256)
        self.inp_bn_2 = tf.keras.layers.BatchNormalization()

        self.inp_emb_3 = tf.keras.layers.Dense(512)
        self.inp_bn_3 = tf.keras.layers.BatchNormalization()

        # Per-face prediction head.
        self.prd_dense_1 = tf.keras.layers.Dense(512)
        self.prd_cbn1 = tf.keras.layers.BatchNormalization()
        self.LRelu = tf.keras.layers.LeakyReLU(alpha=0.3)

        self.prd_dense_2 = tf.keras.layers.Dense(1024)
        self.prd_cbn2 = tf.keras.layers.BatchNormalization()

        self.prd_dense_3 = tf.keras.layers.Dense(1024)
        self.prd_cbn3 = tf.keras.layers.BatchNormalization()

        self.prd_dense_4 = tf.keras.layers.Dense(4)

        self.budget = budget

        self.aug_f = active_augmentation

    def call(self, inputs, training = True):
        """Return per-face class probabilities (softmax over 4 classes).

        Args:
            inputs: ``[x, a, i, faces, labels, inds]`` as built by
                ``get_batch``; ``labels`` and ``inds`` are not used here.
            training: forwarded to BatchNormalization and the GAT branch.
                The default is ``True``; pass ``False`` for inference.
        """
        x, a, i, faces, _, _ = inputs

        # Learned alignment of the input coordinates: x <- (T @ x^T)^T.
        T = self.t_net(x)
        x = tf.transpose(tf.matmul(T, tf.transpose(x)))

        for dense, norm in ((self.inp_emb_1, self.inp_bn_1),
                            (self.inp_emb_2, self.inp_bn_2),
                            (self.inp_emb_3, self.inp_bn_3)):
            x = self.LRelu(norm(dense(x), training=training))

        n_features = self.n_m([x, a, i], training=training)

        # A face's feature is the mean of its three vertices' node features.
        gather_n_f = tf.gather_nd(n_features, tf.reshape(faces, [-1, 1]))
        x = tf.reduce_mean(tf.reshape(gather_n_f, [tf.shape(faces)[0], 3, -1]), 1)

        for dense, norm in ((self.prd_dense_1, self.prd_cbn1),
                            (self.prd_dense_2, self.prd_cbn2),
                            (self.prd_dense_3, self.prd_cbn3)):
            x = self.LRelu(norm(dense(x), training=training))

        return tf.math.softmax(self.prd_dense_4(x))

    def get_batch(self, dataset, batch, training = True):
        """Assemble model inputs for one ``DisjointLoader`` batch.

        Faces of every mesh are offset into the disjoint (concatenated) node
        numbering, and ``inds`` records which graph of the batch each face
        belongs to. If more than ``self.budget`` faces result, a random
        subset of that size is kept.

        Args:
            dataset: dataset the batch was drawn from; its ``raw`` rows supply
                node arrays (column 0), faces (column 1) and labels (column -1).
            batch: ``(inputs, y)`` from ``DisjointLoader``, where ``y`` holds
                the raw row index of each graph (``Graph.y``).
            training: apply ``self.aug_f`` (if not None) to the vertices.

        Returns:
            ``[x (float32), a, i, faces, labels (float32), inds (int32)]``.
        """
        face_parts, label_parts, ind_parts = [], [], []
        base = 0
        for graph_idx, mesh in enumerate(dataset.raw[batch[1]]):
            mesh_faces = mesh[0, 1]
            face_parts.append(mesh_faces + base)
            label_parts.append(mesh[0, -1])
            ind_parts.append(np.full(mesh_faces.shape[0], graph_idx))
            # The next mesh's vertex indices start after this mesh's nodes.
            base += mesh[0, 0].shape[0]
        # Concatenate once instead of growing arrays inside the loop.
        faces = np.concatenate(face_parts, 0)
        labels = np.concatenate(label_parts, 0)
        inds = np.concatenate(ind_parts, 0)

        if faces.shape[0] > self.budget:
            keep = np.random.choice(faces.shape[0], size=self.budget, replace=False)
            faces, labels, inds = faces[keep], labels[keep], inds[keep]

        x = batch[0][0]
        # Use the configured augmentation (previously hard-wired, ignoring aug_f).
        if training and self.aug_f is not None:
            x = self.aug_f(x)
        x = np.asarray(x, dtype=np.float32)
        return [x, batch[0][1], batch[0][2], faces, labels.astype(np.float32), inds.astype(np.int32)]

    def evaluate(self, batch):
        """Return ``(loss, accuracy)`` for one prepared batch in inference mode.

        ``call`` defaults to ``training=True``, which would enable attention
        dropout and make BatchNormalization use and update batch statistics
        with evaluation data, so ``training=False`` is passed explicitly.
        """
        cce = tf.keras.losses.CategoricalCrossentropy()
        m = tf.keras.metrics.CategoricalAccuracy()

        y_pred = self.call(batch, training=False)
        y = batch[4]

        loss = cce(y, y_pred)
        m.update_state(y, y_pred)

        return loss, m.result()

    def _validate(self, loader_val, val_data, n_steps):
        """Face-weighted mean ``(loss, accuracy)`` over ``n_steps`` batches of ``loader_val``."""
        l_ov = 0.
        acc_ov = 0.
        n_t = 0
        for _ in range(n_steps):
            batch = self.get_batch(val_data, next(loader_val), training=False)
            loss, acc = self.evaluate(batch)
            nf = batch[3].shape[0]
            n_t += nf
            l_ov += loss * nf
            acc_ov += acc * nf
        return l_ov / n_t, acc_ov / n_t

    def get_training_step(self):
        """Return a compiled ``train_step(batch, optimizer, m) -> (loss, batch_accuracy)``.

        The returned loss is the categorical cross-entropy; the gradient step
        also includes the layers' regularization losses. ``m`` is reset on
        every call, so the accuracy covers the current batch only.
        """
        @tf.function(reduce_retracing=True)
        def train_step(batch,optimizer,m):
            cce = tf.keras.losses.CategoricalCrossentropy()
            m.reset_state()
            with tf.GradientTape() as tape:
                y_pred = self.call(batch, training=True)
                y = batch[4]
                loss = cce(y, y_pred)
                # Custom loops must add regularization terms themselves;
                # otherwise the GATConv L2 regularizers have no effect.
                total_loss = loss + sum(self.losses)

            gradients = tape.gradient(total_loss, self.trainable_variables)
            optimizer.apply_gradients(zip(gradients, self.trainable_variables))

            m.update_state(y, y_pred)
            return loss, m.result()

        return train_step

    def evaluate_on_data(self, val_data):
        """Print and return face-weighted ``(loss, accuracy)`` over ``val_data``."""
        loader_val = DisjointLoader(val_data, batch_size=1)
        loss, acc = self._validate(loader_val, val_data, loader_val.steps_per_epoch)

        print('Validation Loss: %f, Validation Accuracy: %f' % (loss, acc))

        return loss, acc

    def train(self, dataset, val_data, epochs = 150, lr = 1e-4, save_name=None):
        """Train one graph per step with a staircase exponentially decaying LR.

        After every epoch the model is evaluated on ``val_data``; when
        ``save_name`` is given, weights are saved whenever validation accuracy
        matches or exceeds the best value seen so far.

        Args:
            dataset: training dataset (must expose ``raw`` for ``get_batch``).
            val_data: validation dataset (must expose ``raw``).
            epochs: number of passes over ``dataset``.
            lr: initial Adam learning rate.
            save_name: checkpoint path for ``save_weights``; None disables saving.
        """
        loader = DisjointLoader(dataset, batch_size=1)
        loader_val = DisjointLoader(val_data, batch_size=1)
        n_st = loader.steps_per_epoch
        n_st_val = loader_val.steps_per_epoch

        # LR is multiplied by 0.65 every (epochs // 15) epochs. max(1, ...)
        # keeps decay_steps positive: with epochs < 15 it was 0, which divides
        # by zero inside the schedule.
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            lr,
            decay_steps=max(1, n_st * (epochs // 15)),
            decay_rate=0.65,
            staircase=True)

        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

        train_step = self.get_training_step()
        m = tf.keras.metrics.CategoricalAccuracy()

        best_val = 0.0

        for epoch in range(epochs):
            prog = trange(n_st)
            l_ov = 0.
            acc_ov = 0.
            n_t = 0
            for _ in prog:
                batch = self.get_batch(dataset, next(loader))
                loss, acc = train_step(batch, optimizer, m)

                # Weight each batch by its face count -> per-face running means.
                nf = batch[3].shape[0]
                n_t += nf
                l_ov += loss * nf
                acc_ov += acc * nf

                # Public equivalent of the private optimizer._decayed_lr(),
                # which newer Keras optimizers no longer provide.
                current_lr = lr_schedule(optimizer.iterations)
                prog.set_postfix_str('Loss: %f, Overall Loss: %f, Accuracy: %f, Overall Accuracy: %f, Learning Rate: %e' % (loss, l_ov/n_t, acc, acc_ov/n_t, current_lr))

            print('Epoch %i Loss: %f, Accuracy: %f' % (epoch+1, l_ov/n_t, acc_ov/n_t))

            # Validation batches are looked up in val_data, the dataset the
            # validation loader actually iterates.
            val_loss, val_acc = self._validate(loader_val, val_data, n_st_val)
            print('Epoch %i Validation Loss: %f, Validation Accuracy: %f' % (epoch+1, val_loss, val_acc))

            if float(val_acc) >= best_val:
                best_val = float(val_acc)
                if save_name is not None:
                    self.save_weights(save_name)

### Training The Model On the Data
Now we will train the model. Note that the best-performing checkpoint (by validation accuracy) is saved under the path given by `save_name` in the training function.

In [None]:
# Batches with more faces than this budget are randomly subsampled in get_batch.
test_models = surface_classifier_with_tnet(budget=500_000)

In [None]:
# Training (about 4 hours) saves the best checkpoint by validation accuracy to
# CHECKPOINT_PATH. Set TRAIN_MODEL = True to retrain; otherwise the existing
# pre-trained checkpoint is used.
TRAIN_MODEL = False
CHECKPOINT_PATH = "./CheckPoints/T-Net_Classifier"

if TRAIN_MODEL:
    test_models.train(gat_data, gat_data_val, lr=1e-4, epochs=150, save_name=CHECKPOINT_PATH)

# Load the best checkpoint (freshly trained or pre-trained).
test_models.load_weights(CHECKPOINT_PATH)

### Evaluate The Trained Model
Below are a few lines of code to evaluate and visualize the model.

#### Loss And Accuracy on Validation Data
We will measure the accuracy of the model on validation data it has not seen during training.

In [None]:
# Face-weighted validation loss and accuracy of the loaded model.
loss, acc = test_models.evaluate_on_data(val_data=gat_data_val)

#### Visualization Functions

Functions that render the model's predictions with polyscope and save screenshots.

In [None]:
def render_prediction(v,f,l,save_file ='render.png', view='isometric',shadow=False,colorize=True, def_color=(0.7,0.7,0.7)):
    """Render a mesh segmentation with polyscope and save a screenshot.

    Args:
        v: vertex array passed to ``ps.register_surface_mesh``.
        f: face array (rows of vertex indices).
        l: per-face label matrix with 4 columns; with ``colorize=True`` the
            faces with a nonzero entry in column k are drawn in the k-th colour
            (red, green, blue, cyan).
        save_file: screenshot path.
        view: 'isometric', 'top', 'bottom', 'left', 'right', 'front' or
            'back'; any other value keeps polyscope's current camera.
        shadow: draw a shadow-only ground plane instead of none.
        colorize: colour faces by class; otherwise draw the whole mesh in
            ``def_color``.
        def_color: RGB colour used when ``colorize`` is False.
    """
    class_colors = ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 1.0])
    camera_eyes = {
        'isometric': (-1.0, -1.0, 1.0),
        'top': (-0.0, -0.001, 1.8),
        'bottom': (-0.0, -0.001, -1.8),
        'left': (-0.0, -1.8, 0.0),
        'right': (-0.0, 1.8, 0.0),
        'front': (-1.8, 0.0, 0.0),
        'back': (1.8, 0.0, 0.0),
    }

    ps.set_autocenter_structures(True)
    ps.set_autoscale_structures(True)
    ps.init()
    # Registered structures persist between calls; clear them so parts of a
    # previously rendered mesh cannot leak into this screenshot.
    ps.remove_all_structures()

    if colorize:
        for k, color in enumerate(class_colors):
            class_faces = f[np.flatnonzero(l[:, k])]
            if len(class_faces) == 0:
                continue  # no faces of this class: nothing to draw
            name = "my mesh" if k == 0 else "my mesh %i" % k
            ps.register_surface_mesh(name, v, class_faces, color=color)
    else:
        ps.register_surface_mesh("my mesh", v, f, color=def_color)

    ps.set_ground_plane_height_factor(0.08)
    ps.set_up_dir("z_up")
    ps.set_ground_plane_mode("shadow_only" if shadow else "none")

    eye = camera_eyes.get(view)
    if eye is not None:
        ps.look_at(eye, (0.0, 0.0, 0.0))

    ps.screenshot(filename=save_file)
    
def render_centainty(v,f,l,l_pred,save_file ='render.png', view='isometric',shadow=False,colorize=True, def_color=(0.7,0.7,0.7)):
    """Render per-face prediction certainty with polyscope and save a screenshot.

    With ``colorize=True`` face ``j`` is coloured ``(0, 0, c_j)`` where
    ``c_j = sum(l[j] * l_pred[j])``. For one-hot ``l`` this is the predicted
    probability of the labelled class, so dark faces are low-certainty.

    Args:
        v: vertex array.
        f: face array (rows of vertex indices).
        l: per-face label matrix (only the first ``len(f)`` rows are used).
        l_pred: per-face predicted probabilities, same layout as ``l``.
        save_file: screenshot path.
        view: 'isometric', 'top', 'bottom', 'left', 'right', 'front' or
            'back'; any other value keeps polyscope's current camera.
        shadow: draw a shadow-only ground plane instead of none.
        colorize: colour faces by certainty; otherwise use ``def_color``.
        def_color: RGB colour for the uncoloured mesh.
    """
    camera_eyes = {
        'isometric': (-1.0, -1.0, 1.0),
        'top': (-0.0, -0.001, 1.8),
        'bottom': (-0.0, -0.001, -1.8),
        'left': (-0.0, -1.8, 0.0),
        'right': (-0.0, 1.8, 0.0),
        'front': (-1.8, 0.0, 0.0),
        'back': (1.8, 0.0, 0.0),
    }

    ps.set_autocenter_structures(True)
    ps.set_autoscale_structures(True)
    ps.init()
    # Registered structures persist between calls; clear leftovers first.
    ps.remove_all_structures()

    ps_mesh = ps.register_surface_mesh("my mesh", v, f, color=def_color)
    if colorize:
        n_faces = len(f)
        certainty = np.sum(np.asarray(l)[:n_faces] * np.asarray(l_pred)[:n_faces], axis=-1)
        face_colors = np.zeros((n_faces, 3))
        face_colors[:, 2] = certainty
        # One mesh with a per-face colour quantity instead of one polyscope
        # structure per face, which scaled very poorly with mesh size.
        ps_mesh.add_color_quantity("certainty", face_colors, defined_on='faces', enabled=True)

    ps.set_ground_plane_height_factor(0.08)
    ps.set_up_dir("z_up")
    ps.set_ground_plane_mode("shadow_only" if shadow else "none")

    eye = camera_eyes.get(view)
    if eye is not None:
        ps.look_at(eye, (0.0, 0.0, 0.0))

    ps.screenshot(filename=save_file)

#### Other Metrics and Rendering
The predictions are rendered to the `visualization` folder: `<index>_pred.png` files show the predicted segmentation and `<index>_gt.png` files show the ground truth.

In [None]:
# Top-2 accuracy plus the mean arg-max probability ("certainty") on correctly
# and incorrectly classified validation faces. Every validation mesh is also
# rendered as <index>_pred.png (prediction) and <index>_gt.png (ground truth).
os.makedirs("./visualization", exist_ok=True)

m = tf.keras.metrics.TopKCategoricalAccuracy(k=2)
m.reset_state()
wrong_certainty = 0.
correct_certainty = 0.
nc = 0
nw = 0
# Only loader.collate() is used; build the loader from the dataset iterated below.
loader = DisjointLoader(gat_data_val, batch_size=1, shuffle=False)
for g in tqdm(gat_data_val):
    loader_batch = loader.collate([g])
    ind = loader_batch[1][0][0]  # raw row index stored as Graph.y
    faces = gat_data_val.raw[ind, 1]
    labels = gat_data_val.raw[ind, -1]
    x = loader_batch[0][0].astype(np.float32)
    a = loader_batch[0][1]
    i = loader_batch[0][2]

    # call() defaults to training=True (attention dropout, BatchNorm batch
    # statistics); inference must request training=False explicitly.
    l_pred = test_models([x, a, i, faces, labels, 0], training=False).numpy()
    l_oh = np.eye(l_pred.shape[-1], dtype=bool)[np.argmax(l_pred, axis=-1)]

    render_prediction(x, faces, l_oh, "./visualization/%i_pred.png" % (ind))
    render_prediction(x, faces, labels, "./visualization/%i_gt.png" % (ind))
    m.update_state(labels, l_pred)

    # A face is correct when its arg-max class carries a nonzero label.
    correct = (labels * l_oh).sum(-1) != 0
    top_prob = l_pred.max(-1)
    correct_certainty += top_prob[correct].sum()
    wrong_certainty += top_prob[~correct].sum()
    nc += int(correct.sum())
    nw += int((~correct).sum())

print("Top 2 accuracy: %f" % (m.result().numpy()))
print("Certainty on Correct Predictions: %f" % (correct_certainty / nc if nc else float("nan")))
print("Certainty on Incorrect Predictions: %f" % (wrong_certainty / nw if nw else float("nan")))