In [1]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jaxmao.layers import Conv2D, Dense, BatchNorm, ReLU, Flatten, StableSoftmax, BatchNorm2D, DepthwiseConv2D
from jaxmao.modules import Module
from jaxmao.optimizers import GradientDescent
from jaxmao.losses import CategoricalCrossEntropy
from jaxmao.metrics import Accuracy, Precision, Recall

In [2]:
jax.devices()

I0000 00:00:1697817978.975058  118380 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


[CpuDevice(id=0)]

In [3]:
import jax.numpy as jnp
from jax import jit
import numpy as np

def my_jax_nn(input_data, weights, bias):
    output = jnp.dot(input_data, weights) + bias
    return output

# JIT compile for better performance
my_jax_nn = jit(my_jax_nn)

In [4]:
# Your custom weights and bias
custom_weights = np.array([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2]
], dtype='float32')
custom_bias = np.array([0.01, 0.02, 0.03], dtype='float32')

# Your sample input
sample_input = np.arange(20*4).reshape(20, 4).astype('float32')

# Run the JAX function
jax_output = my_jax_nn(sample_input, custom_weights, custom_bias)
jax_output

Array([[  4.8100004,   5.42     ,   6.03     ],
       [ 13.610001 ,  15.82     ,  18.03     ],
       [ 22.41     ,  26.220001 ,  30.03     ],
       [ 31.210001 ,  36.62     ,  42.03     ],
       [ 40.01     ,  47.02     ,  54.03     ],
       [ 48.81     ,  57.420002 ,  66.03     ],
       [ 57.609997 ,  67.82     ,  78.03     ],
       [ 66.409996 ,  78.22     ,  90.03     ],
       [ 75.21     ,  88.62     , 102.03     ],
       [ 84.01     ,  99.02     , 114.03     ],
       [ 92.810005 , 109.420006 , 126.03     ],
       [101.61     , 119.82     , 138.03     ],
       [110.409996 , 130.22     , 150.03     ],
       [119.21     , 140.62001  , 162.03     ],
       [128.01     , 151.02     , 174.03     ],
       [136.81     , 161.42001  , 186.03     ],
       [145.61     , 171.82     , 198.03     ],
       [154.40999  , 182.22     , 210.03     ],
       [163.20999  , 192.62001  , 222.03     ],
       [172.01     , 203.02     , 234.03     ]], dtype=float32)

# Dense

In [5]:
# Define the input shape and number of output neurons
input_shape = (4,)  # Input shape (number of input features)
output_neurons = 3  # Number of output neurons

# Initialize custom weights and bias
custom_weights = np.array([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2]
])
custom_bias = np.array([0.01, 0.02, 0.03])

In [6]:
dense = Dense(4, 3)
dense.params = dict()
dense.params['weights'] = custom_weights
dense.params['biases'] = custom_bias

bn = BatchNorm(3, momentum=0.5, axis_mean=0)
bn.init_params(None)
bn.set_inference_mode()

In [7]:
sample_input = np.arange(8*4).reshape(8, 4)  # Shape will be (1, 4)
output, _ = dense(dense.params, sample_input)
output

Array([[ 4.8100004,  5.42     ,  6.03     ],
       [13.610001 , 15.82     , 18.03     ],
       [22.41     , 26.220001 , 30.03     ],
       [31.210001 , 36.62     , 42.03     ],
       [40.01     , 47.02     , 54.03     ],
       [48.81     , 57.420002 , 66.03     ],
       [57.609997 , 67.82     , 78.03     ],
       [66.409996 , 78.22     , 90.03     ]], dtype=float32)

## Dense, bn

In [8]:
bn.eps = np.float32(1e-3) # this is what Keras use. I use 1e-5

output2, _ = bn(bn.params, output)
output2

Array([[ 4.807597 ,  5.417292 ,  6.0269876],
       [13.603201 , 15.812096 , 18.020992 ],
       [22.398804 , 26.206902 , 30.014997 ],
       [31.194408 , 36.601704 , 42.009003 ],
       [39.99001  , 46.99651  , 54.003006 ],
       [48.785618 , 57.391315 , 65.99701  ],
       [57.581215 , 67.78612  , 77.99101  ],
       [66.376816 , 78.18092  , 89.98502  ]], dtype=float32)

In [9]:
bn.params, bn.state

({'gamma': Array([1., 1., 1.], dtype=float32),
  'beta': Array([0., 0., 0.], dtype=float32)},
 {'running_mean': Array([0., 0., 0.], dtype=float32),
  'running_var': Array([1., 1., 1.], dtype=float32),
  'momentum': 0.5,
  'training': False})

In [10]:
output3 = ReLU()(np.array([[-5, 0, 5]]))
output3

Array([[0, 0, 5]], dtype=int32)

# Conv2D

In [11]:
sample_input = np.arange(25).reshape(1, 5, 5, 1).astype('float32')

#### valid padding

In [12]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='VALID')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [13]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 3, 3, 1),
 Array([36.609997, 41.109997, 45.609997, 59.109997, 63.609997, 68.11    ,
        81.61001 , 86.10999 , 90.61    ], dtype=float32))

#### same padding

In [14]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='SAME')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [15]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 5, 5, 1),
 Array([10.01    , 16.31    , 20.21    , 24.109999, 16.01    , 24.31    ,
        36.609997, 41.109997, 45.609997, 29.11    , 40.81    , 59.109997,
        63.609997, 68.11    , 42.609997, 57.31    , 81.61001 , 86.10999 ,
        90.61    , 56.110004, 30.410002, 41.51    , 43.609997, 45.71    ,
        26.81    ], dtype=float32))

# Conv2D, bn

In [16]:
sample_input = np.arange(25).reshape(1, 5, 5, 1).astype('float32')

#### valid padding

In [17]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='VALID')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [18]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 3, 3, 1),
 Array([36.609997, 41.109997, 45.609997, 59.109997, 63.609997, 68.11    ,
        81.61001 , 86.10999 , 90.61    ], dtype=float32))

In [19]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn.eps = np.float32(1e-3)

In [20]:
bn_out, state = convbn(convbn.params, conv_output)
bn_out.shape, bn_out.ravel()

((1, 3, 3, 1),
 Array([36.591705, 41.08946 , 45.58721 , 59.080467, 63.578217, 68.07597 ,
        81.56924 , 86.06697 , 90.564735], dtype=float32))

#### same padding

In [21]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='SAME')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [22]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 5, 5, 1),
 Array([10.01    , 16.31    , 20.21    , 24.109999, 16.01    , 24.31    ,
        36.609997, 41.109997, 45.609997, 29.11    , 40.81    , 59.109997,
        63.609997, 68.11    , 42.609997, 57.31    , 81.61001 , 86.10999 ,
        90.61    , 56.110004, 30.410002, 41.51    , 43.609997, 45.71    ,
        26.81    ], dtype=float32))

In [23]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn.eps = np.float32(1e-3)

In [24]:
bn_out, state = convbn(convbn.params, conv_output)
bn_out.shape, bn_out.ravel()

((1, 5, 5, 1),
 Array([10.004999, 16.301851, 20.199902, 24.097954, 16.002   , 24.297853,
        36.591705, 41.08946 , 45.58721 , 29.095457, 40.78961 , 59.080467,
        63.578217, 68.07597 , 42.58871 , 57.28137 , 81.56924 , 86.06697 ,
        90.564735, 56.081974, 30.39481 , 41.48926 , 43.58821 , 45.687164,
        26.796606], dtype=float32))

In [25]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn2d = BatchNorm2D(1)
convbn2d.init_params()
convbn2d.set_inference_mode()

In [26]:
convbn(convbn.params, bn_out)[0].ravel() == convbn2d(convbn2d.params, bn_out)[0].ravel()

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True], dtype=bool)

# Depthwise conv2d

In [79]:
sample_input = np.arange(50).reshape(1, 5, 5, 2).astype('float32')

custom_depthwise_kernel = jnp.array([[[[0.1, 0.2]], [[0.3, 0.4]], [[0.5, 0.6]]],
                                    [[[0.7, 0.8]], [[0.9, 1.0]], [[1.1, 1.2]]],
                                    [[[1.3, 1.4]], [[1.5, 1.6]], [[1.7, 1.8]]]]).reshape(3, 3, 1, 2)

custom_bias = jnp.array([0.01, 0.02])

dwconv = DepthwiseConv2D(2, 1, (3, 3), 1, 'relu', padding='VALID')
dwconv.init_params(key=jax.random.PRNGKey(1))
# dwconv.params = dict()
# dwconv.params['weights'] = custom_depthwise_kernel
# dwconv.params['biases'] = custom_bias

In [80]:
dwconv.params['weights'].shape

(3, 3, 1, 2)

In [81]:
output, s = dwconv(dwconv.params, sample_input)
output.shape, output

((1, 3, 3, 2),
 Array([[[[31.62529 , 27.67883 ],
          [31.967676, 32.388893],
          [32.310062, 37.098957]],
 
         [[33.33722 , 51.22915 ],
          [33.679604, 55.939217],
          [34.02198 , 60.649284]],
 
         [[35.04914 , 74.77947 ],
          [35.391525, 79.48954 ],
          [35.733913, 84.19961 ]]]], dtype=float32))

# sublayers

In [82]:
import jax
import numpy as np
from jaxmao.layers import Layer, Dense

class TwoDense(Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers = {
            'dense1': Dense(4, 4, activation=None, use_bias=False),
            'dense2': Dense(4, 5, activation=None, use_bias=False)
        }
    
    def forward(self, params, x, state=None):
        x, state = self.apply(params, x, 'dense1', state)
        x, state = self.apply(params, x, 'dense2', state)
        return x
dw_sep = TwoDense(4, 5)
dw_sep.init_params(key=jax.random.PRNGKey(0))

In [83]:
dw_sep.layers['dense1'].activation

<jaxmao.layers.Linear at 0x7f2c51725650>

In [84]:
dw_sep.params

{'dense1': {'weights': Array([[ 0.222334  , -0.5509084 ,  0.18274157,  0.33478668],
         [-1.4232535 ,  0.15614432,  0.72407204,  0.73889905],
         [-0.07506051,  1.5381824 , -0.7009841 , -0.551843  ],
         [ 0.19835554,  0.05200171,  0.08526177,  0.5655228 ]],      dtype=float32)},
 'dense2': {'weights': Array([[ 0.92753756, -1.3217152 , -1.1210729 , -0.6640803 , -0.3569061 ],
         [-1.3121504 , -1.6131067 ,  1.0063031 ,  0.5350468 , -0.4035093 ],
         [-1.0435044 , -0.57557756,  0.46654668,  0.6476337 ,  0.2598184 ],
         [-0.37320292, -0.16586637,  0.32066888, -1.3481379 , -0.07582019]],      dtype=float32)}}

In [85]:
sample = np.random.normal(-2, 1, (5, 4))
dw_sep.forward(dw_sep.params, sample, dw_sep.state)

Array([[ 11.063606  ,   2.9078488 , -10.019278  ,  -2.7214916 ,
          0.49901664],
       [ 10.39292   ,   4.6138706 ,  -9.544006  ,  -3.522541  ,
          1.3076421 ],
       [  2.7269378 ,   4.02012   ,  -2.1544874 ,   1.8654947 ,
          1.2582066 ],
       [ 10.754398  ,   0.7795911 ,  -9.53331   ,  -0.3341658 ,
         -0.44027317],
       [  6.844033  ,   4.1617756 ,  -5.90447   ,   0.6955593 ,
          1.0799253 ]], dtype=float32)

In [86]:
print(dw_sep.summary)

layer                output shape         #'s params           #'s states          
dense1               (5, 4)               16                   0                   
dense2               (5, 5)               20                   0                   



# depthwise separable conv2d

In [87]:
import jax
import numpy as np
from jaxmao.layers import Layer, Dense, Conv2D, DepthwiseConv2D
from jaxmao.initializers import *

class DepthwiseSeparableConv2D(Layer):
    def __init__(
            self,
            in_channels,
            out_channels,
            depth_multiplier=1,
            activation='relu',
            padding='SAME',
            strides=(1, 1),
            dilation=(1, 1),
            use_bias=True,
            weights_initializer=HeNormal(),
            bias_initializer=HeNormal(),
            dtype=jnp.float32
                 ):
        super().__init__()
        self.layers = {
            'depthwise' : DepthwiseConv2D(
                                in_channels, depth_multiplier, 
                                kernel_size=(3, 3), strides=strides, 
                                activation=activation, padding=padding,
                                dilation=dilation, use_bias=use_bias,
                                weights_initializer=weights_initializer,
                                bias_initializer=bias_initializer,
                                dtype=dtype
                                          ),
            'pointwise' : Conv2D(
                            in_channels, out_channels, 
                            kernel_size=(1, 1), strides=(1, 1),
                            activation=activation,
                            dilation=dilation, use_bias=use_bias,
                            weights_initializer=weights_initializer,
                            bias_initializer=bias_initializer,
                            dtype=dtype
                                 )
        }
    
    def forward(self, params, x, state=None):
        x, state = self.apply(params, x, 'depthwise', state)
        x, state = self.apply(params, x, 'pointwise', state)
        return x, state

In [88]:
sample = np.random.normal(0, 1, (5, 12, 12, 5))

dwsep = DepthwiseSeparableConv2D(5, 7)
dwsep.init_params(key=jax.random.PRNGKey(4))

In [90]:
x, s = dwsep.forward(dwsep.params, sample, dwsep.state)
x.shape

(5, 12, 12, 7)