**Importing Libraries**

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn

**Creating Edges of the Graph**

In [None]:
def pairwise_distance(x):
    """
    Squared Euclidean distance between every pair of points of a point cloud.

    Relies on the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, so the
    whole table is produced by a single batched matmul.
    Args:
        x: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points)
    """
    with tf.name_scope('pairwise_distance'):
        # Per-point squared norms, kept as a column: (batch, num_points, 1).
        sq_norms = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
        # -2 * <x_i, x_j> for every pair: (batch, num_points, num_points).
        cross_term = -2 * tf.matmul(x, tf.transpose(x, perm=[0, 2, 1]))
        sq_norms_row = tf.transpose(sq_norms, perm=[0, 2, 1])
        return sq_norms + cross_term + sq_norms_row

def dense_knn_matrix(x, k=16, relative_pos=None):
    """Get KNN based on the pairwise distance.

    Args:
        x: (batch_size, num_dims, num_points, 1)
        k: int, number of neighbours kept for every point
        relative_pos: optional tensor added to the distance table before the
            top-k selection. It is added directly to a
            (batch_size, num_points, num_points) tensor, so it has to
            broadcast against that shape.
    Returns:
        edge index: (2, batch_size, num_points, k). Entry 0 holds the
        neighbour indices, entry 1 holds the index of the centre point itself
        repeated k times.
    """
    with tf.name_scope('dense_knn_matrix'):
        # (batch, dims, points, 1) -> (batch, points, dims)
        x = tf.squeeze(tf.transpose(x, perm=[0, 2, 1, 3]), axis=-1)
        dynamic_shape = tf.shape(x)
        batch_size, n_points = dynamic_shape[0], dynamic_shape[1]
        static_points = x.shape[1]

        ### memory efficient implementation ###
        # Chunking only bounds the size of the distance table; every row of
        # the result is independent, so both paths return the same indices.
        # The chunk boundaries are plain Python ints, hence the chunked path
        # is taken only when the point count is statically known.
        n_part = 10000
        if static_points is not None and static_points > n_part:
            # Terms of the distance that do not depend on the chunk.
            x_t = tf.transpose(x, perm=[0, 2, 1])
            x_square_row = tf.transpose(
                tf.reduce_sum(tf.square(x), axis=-1, keepdims=True), perm=[0, 2, 1])
            nn_idx_list = []
            for start_idx in range(0, static_points, n_part):
                end_idx = min(static_points, start_idx + n_part)
                # Distance from the points of this chunk to all points,
                # computed inline (no part_pairwise_distance helper is
                # defined in this file).
                x_part = x[:, start_idx:end_idx]
                part_square = tf.reduce_sum(tf.square(x_part), axis=-1, keepdims=True)
                dist = part_square + (-2 * tf.matmul(x_part, x_t)) + x_square_row
                if relative_pos is not None:
                    dist += relative_pos[:, start_idx:end_idx]
                # top_k of the negated distance == k smallest distances.
                _, nn_idx_part = tf.math.top_k(-dist, k=k)
                nn_idx_list.append(nn_idx_part)
            nn_idx = tf.concat(nn_idx_list, axis=1)
        else:
            dist = pairwise_distance(x)
            if relative_pos is not None:
                dist += relative_pos
            _, nn_idx = tf.math.top_k(-dist, k=k)
        ######

        # center_idx[b, i, :] == i. tf.tile needs exactly one multiple per
        # input dimension, so the range is first lifted to rank 3.
        center_idx = tf.tile(tf.range(n_points)[tf.newaxis, :, tf.newaxis],
                             [batch_size, 1, k])
    return tf.stack((nn_idx, center_idx), axis=0)

def xy_dense_knn_matrix(x, y, k=16, relative_pos=None):
    """Get KNN of the points of x among the points of y.

    Args:
        x: (batch_size, num_dims, num_points, 1), the query points
        y: (batch_size, num_dims, num_points_y, 1), the candidate neighbours
        k: int, number of neighbours kept for every point of x
        relative_pos: optional tensor added to the distance table before the
            top-k selection. It is added directly to a
            (batch_size, num_points, num_points_y) tensor, so it has to
            broadcast against that shape.
    Returns:
        edge index: (2, batch_size, num_points, k). Entry 0 holds indices
        into the points of y, entry 1 holds the index of the centre point of
        x repeated k times.
    """
    with tf.name_scope('xy_dense_knn_matrix'):
        # (batch, dims, points, 1) -> (batch, points, dims)
        x = tf.squeeze(tf.transpose(x, perm=[0, 2, 1, 3]), axis=-1)
        y = tf.squeeze(tf.transpose(y, perm=[0, 2, 1, 3]), axis=-1)
        dynamic_shape = tf.shape(x)
        batch_size, n_points = dynamic_shape[0], dynamic_shape[1]

        # Squared distance ||x_i||^2 - 2 <x_i, y_j> + ||y_j||^2, computed
        # inline (no xy_pairwise_distance helper is defined in this file).
        x_square = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
        y_square = tf.reduce_sum(tf.square(y), axis=-1, keepdims=True)
        xy_inner = -2 * tf.matmul(x, tf.transpose(y, perm=[0, 2, 1]))
        dist = x_square + xy_inner + tf.transpose(y_square, perm=[0, 2, 1])
        if relative_pos is not None:
            dist += relative_pos

        # top_k of the negated distance == k smallest distances.
        _, nn_idx = tf.math.top_k(-dist, k=k)
        # center_idx[b, i, :] == i. tf.tile needs exactly one multiple per
        # input dimension, so the range is first lifted to rank 3.
        center_idx = tf.tile(tf.range(n_points)[tf.newaxis, :, tf.newaxis],
                             [batch_size, 1, k])
    return tf.stack((nn_idx, center_idx), axis=0)

class DenseDilated(tf.Module):
    """
    Pick a dilated subset of neighbours out of a neighbour list.

    edge_index: (2, batch_size, num_points, k)
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super().__init__()
        self.k = k
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon

    def __call__(self, edge_index):
        """Thin out the last axis of `edge_index`.

        By default every `dilation`-th entry of the last axis is kept. When
        `stochastic` is set, a uniform draw below `epsilon` combined with a
        truthy Keras learning phase replaces that with `k` entries picked at
        random from the first `k * dilation` positions.
        """
        with tf.name_scope('DenseDilated'):
            if self.stochastic:
                if tf.random.uniform(shape=()) < self.epsilon and tf.keras.backend.learning_phase():
                    pool_size = self.k * self.dilation
                    picked = tf.random.shuffle(tf.range(pool_size))[:self.k]
                    return tf.gather(edge_index, picked, axis=-1)
            # Regular dilation: fixed stride over the neighbour axis.
            return edge_index[:, :, :, ::self.dilation]

class DenseDilatedKnnGraph(tf.Module):
    """
    Build the dilated k-nearest-neighbour index of a set of points.
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super().__init__()
        self.k = k
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)

    def __call__(self, x, y=None, relative_pos=None):
        """Return the dilated neighbour index of `x`.

        Inputs are L2-normalised along axis 1 before the search. With `y`
        given, neighbours of `x` are searched among `y`; otherwise among `x`
        itself. `k * dilation` candidates are retrieved and then thinned out
        by the DenseDilated module.
        """
        with tf.name_scope('DenseDilatedKnnGraph'):
            num_candidates = self.k * self.dilation
            x = tf.nn.l2_normalize(x, axis=1)
            if y is None:
                edge_index = dense_knn_matrix(x, num_candidates, relative_pos)
            else:
                y = tf.nn.l2_normalize(y, axis=1)
                edge_index = xy_dense_knn_matrix(x, y, num_candidates, relative_pos)
            return self._dilated(edge_index)


**Creating Vertices of the Graph**

In [None]:
class BasicConv(layers.Layer):
    """1x1 convolution, optionally followed by batch norm and ReLU.

    Args:
        filters: indexable pair; only `filters[1]` (the output width) is used.
        activation: the ReLU is applied only when this equals 'relu'; any
            other value leaves the output un-activated.
        norm: 'batch' adds a BatchNormalization layer, anything else adds none.
        use_bias: forwarded to the convolution.

    NOTE(review): the Conv2D uses the Keras default data format while the
    graph layers in this file unpack shapes as (B, C, H, W) -- confirm the
    intended layout.
    """
    def __init__(self, filters, activation='relu', norm=None, use_bias=True):
        super().__init__()
        self.conv = layers.Conv2D(filters=filters[1], kernel_size=1, strides=1, padding='valid', use_bias=use_bias)
        if norm == 'batch':
            self.batch_norm = layers.BatchNormalization()
        else:
            self.batch_norm = None
        self.activation = activation

    def call(self, inputs):
        out = self.conv(inputs)
        if self.batch_norm is not None:
            out = self.batch_norm(out)
        return nn.relu(out) if self.activation == 'relu' else out

def batched_index_select(x, indices):
    """Fetch the features of the points referenced by `indices`.

    Args:
        x: features, (batch_size, num_dims, num_points, 1) for the graph
            layers in this file, which all pass one half of an edge index.
        indices: integer tensor. Either (batch_size, num_queries, k) point
            indices, or a rank-1 (batch_size,) tensor selecting one entry of
            axis 1 per batch element.
    Returns:
        (batch_size, num_dims, num_queries, k) neighbour features for rank-3
        indices; `x[b, indices[b]]` stacked over the batch for rank-1 indices.
    """
    indices = tf.convert_to_tensor(indices)
    if indices.shape.rank == 1:
        # One selection per batch element: pair every index with its batch
        # position and let gather_nd pick x[b, indices[b]].
        batch_size = tf.shape(indices)[0]
        batch_indices = tf.range(0, batch_size)
        return tf.gather_nd(x, tf.stack([batch_indices, indices], axis=1))

    # (batch, dims, points, 1) -> (batch, points, dims) so that a batched
    # gather over axis 1 returns whole feature vectors.
    point_features = tf.transpose(tf.squeeze(x, axis=-1), perm=[0, 2, 1])
    # (batch, num_queries, k, dims)
    gathered = tf.gather(point_features, indices, batch_dims=1)
    # Back to channels-first: (batch, dims, num_queries, k)
    return tf.transpose(gathered, perm=[0, 3, 1, 2])

class MRConv2d(layers.Layer):
    """Max-relative graph convolution.

    Every point is described by its own features interleaved with the
    element-wise maximum of (neighbour - centre) over its neighbours, and the
    result is passed through a BasicConv.
    """
    def __init__(self, in_channels, out_channels, activation='relu', norm=None, use_bias=True):
        super(MRConv2d, self).__init__()
        self.nn = BasicConv([in_channels*2, out_channels], activation, norm, use_bias)

    def call(self, x, edge_index, y=None):
        """Apply the convolution.

        Args:
            x: point features, unpacked below as (batch, channels, points, last).
            edge_index: indexable pair; entry 0 selects the neighbours and
                entry 1 the centre points.
            y: optional alternative source of neighbour features.
        """
        x_i = batched_index_select(x, edge_index[1])
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        # tf.math.reduce_max returns only the values (no indices), so there
        # is nothing to unpack here.
        x_j = tf.math.reduce_max(x_j - x_i, axis=-1, keepdims=True)
        _, c, n, last = x.shape
        # Stack own and max-relative features on a new axis next to the
        # channel axis, then fold that axis into the channels.
        x = tf.concat([tf.expand_dims(x, axis=2), tf.expand_dims(x_j, axis=2)], axis=2)
        # -1 for the batch axis: its static size can be None while tracing.
        x = tf.reshape(x, [-1, 2 * c, n, last])
        return self.nn(x)

class EdgeConv2d(layers.Layer):
    """Edge convolution: BasicConv over [centre, neighbour - centre], then max over the last axis."""
    def __init__(self, in_channels, out_channels, activation='relu', norm=None, use_bias=True):
        super().__init__()
        self.nn = BasicConv([in_channels*2, out_channels], activation, norm, use_bias)

    def call(self, x, edge_index, y=None):
        """Apply the convolution.

        `edge_index[1]` selects the centre points from `x`; `edge_index[0]`
        selects the neighbours from `y` when it is given, otherwise from `x`.
        """
        neighbour_source = x if y is None else y
        x_i = batched_index_select(x, edge_index[1])
        x_j = batched_index_select(neighbour_source, edge_index[0])
        # Centre features and relative neighbour features, joined on axis 1.
        edge_features = tf.concat([x_i, x_j - x_i], axis=1)
        return tf.math.reduce_max(self.nn(edge_features), axis=-1, keepdims=True)

class GraphSAGE(layers.Layer):
    """GraphSAGE-style convolution.

    Neighbour features go through a first BasicConv and are max-reduced over
    the last axis; the result is concatenated with the point's own features
    on axis 1 and passed through a second BasicConv.
    """
    def __init__(self, in_channels, out_channels, activation='relu', norm=None, use_bias=True):
        super(GraphSAGE, self).__init__()
        self.nn1 = BasicConv([in_channels, in_channels], activation, norm, use_bias)
        self.nn2 = BasicConv([in_channels*2, out_channels], activation, norm, use_bias)

    def call(self, x, edge_index, y=None):
        """Apply the convolution.

        `edge_index[0]` selects the neighbours from `y` when it is given,
        otherwise from `x`.
        """
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        # tf.math.reduce_max returns only the values (no indices), so there
        # is nothing to unpack here.
        x_j = tf.math.reduce_max(self.nn1(x_j), axis=-1, keepdims=True)
        return self.nn2(tf.concat([x, x_j], axis=1))

class GINConv2d(layers.Layer):
    """GIN-style convolution: BasicConv over (1 + eps) * x + sum of neighbour features."""
    def __init__(self, in_channels, out_channels, activation='relu', norm=None, use_bias=True):
        super().__init__()
        self.nn = BasicConv([in_channels, out_channels], activation, norm, use_bias)
        # Trainable weight of the centre point, starting at 0.
        self.eps = tf.Variable(initial_value=[0.0], trainable=True)

    def call(self, x, edge_index, y=None):
        """Apply the convolution.

        `edge_index[0]` selects the neighbours from `y` when it is given,
        otherwise from `x`; they are summed over the last axis.
        """
        neighbour_source = x if y is None else y
        neighbours = batched_index_select(neighbour_source, edge_index[0])
        neighbour_sum = tf.math.reduce_sum(neighbours, axis=-1, keepdims=True)
        return self.nn((1 + self.eps) * x + neighbour_sum)

class GraphConv2d(layers.Layer):
    """Static graph convolution that delegates to one of the conv variants.

    `conv` selects the variant: 'edge', 'mr', 'sage' or 'gin'. Any other
    value raises NotImplementedError.
    """
    def __init__(self, in_channels, out_channels, conv='edge', activation='relu', norm=None, use_bias=True):
        super().__init__()
        variants = (
            ('edge', EdgeConv2d),
            ('mr', MRConv2d),
            ('sage', GraphSAGE),
            ('gin', GINConv2d),
        )
        for variant_name, variant_cls in variants:
            if conv == variant_name:
                self.gconv = variant_cls(in_channels, out_channels, activation, norm, use_bias)
                break
        else:
            # No variant matched.
            raise NotImplementedError('conv:{} is not supported'.format(conv))

    def call(self, x, edge_index, y=None):
        """Forward all arguments to the selected convolution."""
        return self.gconv(x, edge_index, y)

class DyGraphConv2d(GraphConv2d):
    """Dynamic graph convolution.

    The neighbour graph is rebuilt from the input features on every call
    (dilated kNN) and then handed to the convolution chosen by `conv`.

    Args:
        kernel_size: number of neighbours per point.
        dilation: dilation of the neighbour list.
        r: reduction ratio; when > 1 the neighbours are searched in an
            r x r average-pooled copy of the input.
        The remaining arguments are forwarded to GraphConv2d and
        DenseDilatedKnnGraph.
    """
    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', activation='relu',
                 norm=None, use_bias=True, stochastic=False, epsilon=0.0, r=1):
        super(DyGraphConv2d, self).__init__(in_channels, out_channels, conv, activation, norm, use_bias)
        self.k = kernel_size
        self.d = dilation
        self.r = r
        self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

    def call(self, x, relative_pos=None):
        """Run the convolution on `x`, unpacked here as (B, C, H, W).

        Returns the result reshaped to (B, -1, H, W).
        """
        _, C, H, W = x.shape
        # The static batch size can be None while tracing; use the dynamic one.
        batch = tf.shape(x)[0]
        y = None
        if self.r > 1:
            # tf.nn.avg_pool2d requires an explicit padding and pools in NHWC
            # by default, so pool a channels-last view and transpose back.
            pooled = tf.nn.avg_pool2d(tf.transpose(x, perm=[0, 2, 3, 1]), self.r, self.r, 'VALID')
            y = tf.transpose(pooled, perm=[0, 3, 1, 2])
            y = tf.reshape(y, [batch, C, -1, 1])
        # Flatten the spatial grid into a list of points.
        x = tf.reshape(x, [batch, C, -1, 1])
        edge_index = self.dilated_knn_graph(x, y, relative_pos)
        x = super(DyGraphConv2d, self).call(x, edge_index, y)
        return tf.reshape(x, [batch, -1, H, W])

class Grapher(layers.Layer):
    """Grapher block: 1x1 conv -> dynamic graph conv -> 1x1 conv, plus a shortcut.

    Args:
        in_channels: width of the block input and output.
        kernel_size, dilation, conv, activation, norm, use_bias, stochastic,
            epsilon, r: forwarded to DyGraphConv2d.
        n: number of points the relative position table is built for.
        drop_path: rate of the Dropout applied before the shortcut addition;
            values <= 0 disable it.
        relative_pos: when truthy, a fixed (non-trainable) relative position
            table is created and passed to the graph convolution.
    """
    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', activation='relu', norm=None,
                 use_bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
        super(Grapher, self).__init__()
        self.channels = in_channels
        self.n = n
        self.r = r
        # Sequential rejects None entries, so the optional batch norm is
        # appended only when requested.
        fc1_layers = [layers.Conv2D(filters=in_channels, kernel_size=1, strides=1, padding='valid', use_bias=use_bias)]
        if norm == 'batch':
            fc1_layers.append(layers.BatchNormalization())
        self.fc1 = tf.keras.Sequential(fc1_layers)
        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
                                        activation, norm, use_bias, stochastic, epsilon, r)
        fc2_layers = [layers.Conv2D(filters=in_channels, kernel_size=1, strides=1, padding='valid', use_bias=use_bias)]
        if norm == 'batch':
            fc2_layers.append(layers.BatchNormalization())
        self.fc2 = tf.keras.Sequential(fc2_layers)
        self.drop_path = layers.Dropout(drop_path) if drop_path > 0.0 else tf.keras.layers.Identity()
        self.relative_pos = None
        if relative_pos:
            print('using relative_pos')
            # assumes get_2d_relative_pos_embed returns a 2-D table -- TODO confirm
            table = tf.convert_to_tensor(np.float32(get_2d_relative_pos_embed(in_channels, int(n**0.5))))
            # tf.image.resize works on (batch, height, width, channels).
            table = table[tf.newaxis, :, :, tf.newaxis]
            table = tf.image.resize(table, (n, n//(r*r)), method='bicubic', antialias=False)
            # Stored negated, with the channel axis dropped: (1, n, n // r^2).
            self.relative_pos = tf.Variable(-tf.squeeze(table, axis=-1), trainable=False)

    def _get_relative_pos(self, relative_pos, H, W):
        """Return `relative_pos` resized for an H x W grid.

        The table is returned untouched when it is None or when H * W equals
        the point count it was built for; otherwise it is bicubically resized
        to (H * W, H * W // r^2), keeping its leading axis.
        """
        if relative_pos is None or H * W == self.n:
            return relative_pos
        N = H * W
        N_reduced = N // (self.r * self.r)
        # Add a channel axis for tf.image.resize and drop it again afterwards.
        resized = tf.image.resize(tf.expand_dims(relative_pos, axis=-1), (N, N_reduced), method='bicubic')
        return tf.squeeze(resized, axis=-1)

    def call(self, x):
        """Apply the block; the input is added back as a residual shortcut."""
        shortcut = x
        x = self.fc1(x)
        _, _, H, W = x.shape
        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
        x = self.graph_conv(x, relative_pos)
        x = self.fc2(x)
        return self.drop_path(x) + shortcut

**Main ViG Architecture**

In [None]:
class FFN(tf.keras.layers.Layer):
    """Residual feed-forward block built from two 1x1 convolutions.

    conv -> batch norm -> activation -> conv -> batch norm -> dropout, with
    the block input added to the result. `hidden_features` and `out_features`
    fall back to `in_features` when left unset.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, activation='relu', drop_rate=0.0):
        super().__init__()
        keras_layers = tf.keras.layers
        output_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        self.fc1 = keras_layers.Conv2D(hidden_width, kernel_size=1, strides=1, padding='valid')
        self.bn1 = keras_layers.BatchNormalization()
        self.activation = keras_layers.Activation(activation)
        self.fc2 = keras_layers.Conv2D(output_width, kernel_size=1, strides=1, padding='valid')
        self.bn2 = keras_layers.BatchNormalization()
        self.drop = keras_layers.Dropout(drop_rate)

    def call(self, inputs, training=None):
        """Apply the block; `training` is forwarded to batch norm and dropout."""
        hidden = self.bn1(self.fc1(inputs), training=training)
        hidden = self.activation(hidden)
        projected = self.bn2(self.fc2(hidden), training=training)
        projected = self.drop(projected, training=training)
        return projected + inputs

class Stem(tf.keras.layers.Layer):
    """Convolutional stem.

    Four stride-2 3x3 convolutions widening the features to out_dim//8,
    out_dim//4, out_dim//2 and out_dim (each followed by batch norm and the
    activation), then a stride-1 3x3 convolution with batch norm only.
    """
    def __init__(self, out_dim=768, activation='relu'):
        super().__init__()
        keras_layers = tf.keras.layers
        stack = []
        # Downsampling stages: conv -> batch norm -> activation.
        for width in (out_dim//8, out_dim//4, out_dim//2, out_dim):
            stack += [
                keras_layers.Conv2D(width, kernel_size=3, strides=2, padding='same'),
                keras_layers.BatchNormalization(),
                keras_layers.Activation(activation),
            ]
        # Final stage keeps the resolution and has no activation.
        stack += [
            keras_layers.Conv2D(out_dim, kernel_size=3, strides=1, padding='same'),
            keras_layers.BatchNormalization(),
        ]
        self.convs = stack
        self.stem_layers = tf.keras.Sequential(self.convs)

    def call(self, inputs):
        """Run the inputs through the whole layer stack."""
        return self.stem_layers(inputs)

class DeepGCN(tf.keras.Model):
    """Vision-GNN style classifier: stem -> (Grapher + FFN) blocks -> prediction head.

    Args:
        opt: configuration object. The attributes read here are n_filters, k,
            act, norm, bias, epsilon, use_stochastic, conv, n_blocks,
            drop_path, use_dilation, dropout and n_classes.
    """
    def __init__(self, opt):
        super(DeepGCN, self).__init__()
        channels = opt.n_filters
        k = opt.k
        act = opt.act
        norm = opt.norm
        bias = opt.bias
        epsilon = opt.epsilon
        stochastic = opt.use_stochastic
        conv = opt.conv
        self.n_blocks = opt.n_blocks
        drop_path = opt.drop_path

        self.stem = Stem(out_dim=channels, activation=act)

        # One value per block, as plain Python numbers: they are used as
        # layer hyper-parameters (dropout rates, neighbour counts, slice
        # steps), not as tensors.
        # stochastic depth decay rule
        dpr = [float(rate) for rate in np.linspace(0, drop_path, self.n_blocks)]
        # number of knn's k, growing linearly from k to 2k
        num_knn = [int(count) for count in np.linspace(k, 2 * k, self.n_blocks)]
        max_dilation = 196 // max(num_knn)

        # Position embedding for a 14 x 14 grid of channels-last features.
        self.pos_embed = tf.Variable(tf.zeros((1, 14, 14, channels)))

        self.backbone = []
        for i in range(self.n_blocks):
            # Dilation grows every four blocks when enabled, else stays at 1.
            dilation = min(i // 4 + 1, max_dilation) if opt.use_dilation else 1
            self.backbone.append(tf.keras.Sequential([
                Grapher(channels, num_knn[i], dilation, conv, act, norm,
                        bias, stochastic, epsilon, 1, drop_path=dpr[i]),
                FFN(channels, channels * 4, activation=act, drop_rate=dpr[i])
            ]))

        self.prediction = tf.keras.Sequential([
            tf.keras.layers.Conv2D(1024, kernel_size=1, bias_initializer='zeros'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation(act),
            tf.keras.layers.Dropout(opt.dropout),
            tf.keras.layers.Conv2D(opt.n_classes, kernel_size=1, bias_initializer='zeros')
        ])

    def call(self, inputs):
        """Return the class scores, shape (batch, n_classes)."""
        x = self.stem(inputs) + self.pos_embed

        for block in self.backbone:
            x = block(x)

        # Global average pooling; the two spatial axes are kept (size 1)
        # because the prediction head is made of Conv2D layers, which need
        # rank-4 input. They are removed from the final scores.
        x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        return tf.squeeze(self.prediction(x), axis=[1, 2])