In [1]:
import jax.numpy as jnp
from jax import random, jit, vmap, lax

class RawVersion:
    """Naive, loop-based reference implementations of conv2d and max pooling.

    Meant for cross-checking the `lax`-based versions on small inputs; the
    Python loops are far too slow for large tensors. Inputs are NCHW.
    """

    @staticmethod
    def conv2d(x, w, b, padding=1):
        """Stride-1 2D cross-correlation with symmetric zero padding.

        Args:
            x: input of shape (batch, in_channels, height, width).
            w: kernels of shape (out_channels, in_channels, kh, kw).
            b: bias of shape (out_channels,).
            padding: number of zeros added on each side of height and width.

        Returns:
            Array of shape (batch, out_channels, height + 2*padding - kh + 1,
            width + 2*padding - kw + 1).
        """
        bs, _, he, wi = x.shape  # input graph -> batch_size x channel x height x width
        ocl, _, kh, kw = w.shape
        out_he = he + 2 * padding - kh + 1
        out_wi = wi + 2 * padding - kw + 1

        # Zero-pad only the spatial (height, width) axes.
        pad_mat = (
            (0, 0),
            (0, 0),
            (padding, padding),
            (padding, padding),
        )
        x_padded = jnp.pad(x, pad_mat, mode='constant', constant_values=0)

        fgraph = jnp.zeros((bs, ocl, out_he, out_wi))  # feature graph
        for k in range(ocl):
            for i in range(out_he):
                for j in range(out_wi):
                    window = x_padded[:, :, i:i + kh, j:j + kw]
                    # JAX arrays are immutable: `.at[...].set(...)` returns a new
                    # array, so it must be re-bound or the write is silently lost.
                    fgraph = fgraph.at[:, k, i, j].set(
                        jnp.sum(window * w[k], axis=(1, 2, 3)) + b[k]
                    )

        return fgraph

    @staticmethod
    def max_pooling2d(x, pool_size=(2, 2), stride=None):
        """Max pooling over the height/width axes of an NCHW array.

        Args:
            x: input of shape (batch, channels, height, width).
            pool_size: (pool_height, pool_width) window size.
            stride: (stride_height, stride_width); defaults to `pool_size`
                (non-overlapping windows).

        Returns:
            Array of shape (batch, channels,
            (height - pool_height) // stride_height + 1,
            (width - pool_width) // stride_width + 1). Trailing rows/columns
            that do not fill a whole window are dropped.
        """
        if stride is None:
            stride = pool_size

        batch_size, channels, height, width = x.shape
        pool_height, pool_width = pool_size
        stride_height, stride_width = stride

        output_height = (height - pool_height) // stride_height + 1
        output_width = (width - pool_width) // stride_width + 1

        output_array = jnp.zeros((batch_size, channels, output_height, output_width))

        # Batch and channel are reduced together per window position; only the
        # spatial output grid is looped over.
        for i in range(output_height):
            for j in range(output_width):
                h0 = i * stride_height
                w0 = j * stride_width
                window = x[:, :, h0:h0 + pool_height, w0:w0 + pool_width]
                # Re-bind: `.at[...].set(...)` does not modify the array in place.
                output_array = output_array.at[:, :, i, j].set(
                    jnp.max(window, axis=(2, 3))
                )

        return output_array


class JaxOptimaized:
    """`jax.lax`-backed counterparts of the loop-based conv2d / max pooling.

    Both methods consume and produce NCHW arrays.
    """

    @staticmethod
    def conv2d(x, w, b, padding=1):
        """Stride-1 2D convolution (cross-correlation) with symmetric zero padding.

        Args:
            x: input laid out as NCHW.
            w: kernels laid out as OIHW.
            b: per-output-channel bias, broadcast over batch, height and width.
            padding: zeros added before and after both spatial dimensions.

        Returns:
            NCHW array holding the convolution result plus bias.
        """
        # Identical (low, high) padding along height and along width.
        spatial_pad = ((padding, padding),) * 2
        layout = ('NCHW', 'OIHW', 'NCHW')

        conv_out = lax.conv_general_dilated(
            x,
            w,
            window_strides=(1, 1),
            padding=spatial_pad,
            lhs_dilation=(1, 1),
            rhs_dilation=(1, 1),
            dimension_numbers=layout,
        )

        # Lift the bias to (1, O, 1, 1) so it broadcasts across N, H and W.
        bias = b[jnp.newaxis, :, jnp.newaxis, jnp.newaxis]
        return conv_out + bias

    @staticmethod
    def max_pooling2d(x, pool_size=(2, 2), stride=None):
        """Max pooling over the last two (spatial) axes of an NCHW array.

        Args:
            x: NCHW input.
            pool_size: (height, width) of the pooling window.
            stride: (height, width) step between windows; when omitted it equals
                `pool_size`, i.e. windows do not overlap.

        Returns:
            NCHW array. Windows that would run past the edge are dropped ('VALID').
        """
        steps = pool_size if stride is None else stride

        # Never pool across batch or channel: size-1 window and step on those axes.
        window_dims = (1, 1, pool_size[0], pool_size[1])
        window_steps = (1, 1, steps[0], steps[1])

        return lax.reduce_window(
            operand=x,
            init_value=-jnp.inf,
            computation=lax.max,
            window_dimensions=window_dims,
            window_strides=window_steps,
            padding='VALID',
        )

In [2]:
batch_size = 10000
in_channel = 1
out_channel = 10
padding = 1
kh, kw = (5, 5)  # kernel height x width

X = jnp.ones((batch_size, in_channel, 28, 28))
w = jnp.ones((out_channel, in_channel, kh, kw))
b = jnp.ones((out_channel, ))

# Expected stride-1 conv output size: H + 2*padding - kh + 1 (likewise W with kw).
bs, cl, he, wi = X.shape  # input -> batch x channel x height x width
he = he + 2*padding - kh + 1
wi = wi + 2*padding - kw + 1

# Use the `padding` constant above instead of a hard-coded value so the conv
# output agrees with `he`/`wi`. NOTE: `X` and `padding` are read when the lambda
# is first traced; re-run this cell after changing either.
jit_conv2d = jit(lambda w, b: JaxOptimaized.conv2d(X, w, b, padding=padding))
jit_max_pooling2d = jit(JaxOptimaized.max_pooling2d)

fgraph = jit_conv2d(w, b)
assert fgraph.shape == (bs, out_channel, he, wi), fgraph.shape

In [3]:
fgraph.shape  # conv feature maps: (batch, out_channel, height, width)

(10000, 10, 26, 26)


In [4]:
# Max pooling with the default 2x2 window and stride (stride defaults to the window size).
# NOTE: this re-binds `fgraph`, so running the cell twice pools twice; re-run the conv cell first.
fgraph = jit_max_pooling2d(x=fgraph)

In [5]:
fgraph.shape  # pooled feature maps, still NCHW

(10000, 10, 13, 13)
