In [1]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jaxmao.layers import Conv2D, Dense, BatchNorm, ReLU, Flatten, StableSoftmax, BatchNorm2D, DepthwiseConv2D, Activation
from jaxmao.modules import Module
from jaxmao.optimizers import GradientDescent
from jaxmao.losses import CategoricalCrossEntropy
from jaxmao.metrics import Accuracy, Precision, Recall

In [2]:
jax.devices()

I0000 00:00:1697821211.521425  144504 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


[CpuDevice(id=0)]

# Dense

In [3]:
# Define the input shape and number of output neurons
input_shape = (4,)  # Input shape (number of input features)
output_neurons = 3  # Number of output neurons

# Initialize custom weights and bias
custom_weights = np.array([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2]
])
custom_bias = np.array([0.01, 0.02, 0.03])

In [4]:
dense = Dense(4, 3, use_bias=False)
dense.init_params(key=jax.random.PRNGKey(0))
# dense.params = dict()
# dense.params['weights'] = custom_weights
# dense.params['biases'] = custom_bias

bn = BatchNorm(3, momentum=0.5, axis_mean=0)
bn.init_params(None)
bn.set_inference_mode()

In [5]:
Activation(act='relu')

<jaxmao.layers.ReLU at 0x7f58b03d4550>

In [6]:
dense.params

{'dense/simple_dense': {'weights': Array([[-1.8649967 ,  0.27245682,  0.23650852],
         [-0.5385112 ,  0.25521564, -0.8271961 ],
         [ 0.1533223 ,  0.36550814,  0.7068316 ],
         [-0.79391825,  1.3884199 ,  0.07156017]], dtype=float32),
  'biases': Array([0., 0., 0.], dtype=float32)},
 'dense/batch_norm': {'gamma': Array([1., 1., 1.], dtype=float32),
  'beta': Array([0., 0., 0.], dtype=float32)}}

In [7]:
sample_input = np.arange(8*4).reshape(8, 4)  # Shape will be (1, 4)
output, _ = dense(dense.params, sample_input)
output.shape

(8, 3)

## Dense, bn

In [8]:
bn.eps = np.float32(1e-3) # this is what Keras use. I use 1e-5

output2, _ = bn(bn.params, output)
output2

Array([[1.526762  , 0.        , 0.        ],
       [1.0905442 , 0.        , 0.        ],
       [0.6543265 , 0.        , 0.        ],
       [0.21810873, 0.        , 0.        ],
       [0.        , 0.21810892, 0.21810862],
       [0.        , 0.6543268 , 0.65432507],
       [0.        , 1.0905443 , 1.090543  ],
       [0.        , 1.5267621 , 1.5267589 ]], dtype=float32)

In [9]:
bn.params, bn.state

({'gamma': Array([1., 1., 1.], dtype=float32),
  'beta': Array([0., 0., 0.], dtype=float32)},
 {'running_mean': Array([0., 0., 0.], dtype=float32),
  'running_var': Array([1., 1., 1.], dtype=float32),
  'momentum': 0.5,
  'training': False})

In [10]:
output3 = ReLU()(np.array([[-5, 0, 5]]))
output3

Array([[0, 0, 5]], dtype=int32)

# Conv2D

In [11]:
sample_input = np.arange(25).reshape(1, 5, 5, 1).astype('float32')

#### valid padding

In [12]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='VALID')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [13]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 3, 3, 1),
 Array([36.609997, 41.109997, 45.609997, 59.109997, 63.609997, 68.11    ,
        81.61001 , 86.10999 , 90.61    ], dtype=float32))

#### same padding

In [14]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='SAME')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [15]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 5, 5, 1),
 Array([10.01    , 16.31    , 20.21    , 24.109999, 16.01    , 24.31    ,
        36.609997, 41.109997, 45.609997, 29.11    , 40.81    , 59.109997,
        63.609997, 68.11    , 42.609997, 57.31    , 81.61001 , 86.10999 ,
        90.61    , 56.110004, 30.410002, 41.51    , 43.609997, 45.71    ,
        26.81    ], dtype=float32))

# Conv2D, bn

In [16]:
sample_input = np.arange(25).reshape(1, 5, 5, 1).astype('float32')

#### valid padding

In [17]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='VALID')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [18]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 3, 3, 1),
 Array([36.609997, 41.109997, 45.609997, 59.109997, 63.609997, 68.11    ,
        81.61001 , 86.10999 , 90.61    ], dtype=float32))

In [19]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn.eps = np.float32(1e-3)

In [20]:
bn_out, state = convbn(convbn.params, conv_output)
bn_out.shape, bn_out.ravel()

((1, 3, 3, 1),
 Array([36.591705, 41.08946 , 45.58721 , 59.080467, 63.578217, 68.07597 ,
        81.56924 , 86.06697 , 90.564735], dtype=float32))

#### same padding

In [21]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='SAME')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [22]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 5, 5, 1),
 Array([10.01    , 16.31    , 20.21    , 24.109999, 16.01    , 24.31    ,
        36.609997, 41.109997, 45.609997, 29.11    , 40.81    , 59.109997,
        63.609997, 68.11    , 42.609997, 57.31    , 81.61001 , 86.10999 ,
        90.61    , 56.110004, 30.410002, 41.51    , 43.609997, 45.71    ,
        26.81    ], dtype=float32))

In [23]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn.eps = np.float32(1e-3)

In [24]:
bn_out, state = convbn(convbn.params, conv_output)
bn_out.shape, bn_out.ravel()

((1, 5, 5, 1),
 Array([10.004999, 16.301851, 20.199902, 24.097954, 16.002   , 24.297853,
        36.591705, 41.08946 , 45.58721 , 29.095457, 40.78961 , 59.080467,
        63.578217, 68.07597 , 42.58871 , 57.28137 , 81.56924 , 86.06697 ,
        90.564735, 56.081974, 30.39481 , 41.48926 , 43.58821 , 45.687164,
        26.796606], dtype=float32))

In [25]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn2d = BatchNorm2D(1)
convbn2d.init_params()
convbn2d.set_inference_mode()

In [26]:
convbn(convbn.params, bn_out)[0].ravel() == convbn2d(convbn2d.params, bn_out)[0].ravel()

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True], dtype=bool)

# Depthwise conv2d

In [27]:
sample_input = np.arange(50).reshape(1, 5, 5, 2).astype('float32')

custom_depthwise_kernel = jnp.array([[[[0.1, 0.2]], [[0.3, 0.4]], [[0.5, 0.6]]],
                                    [[[0.7, 0.8]], [[0.9, 1.0]], [[1.1, 1.2]]],
                                    [[[1.3, 1.4]], [[1.5, 1.6]], [[1.7, 1.8]]]]).reshape(3, 3, 1, 2)

custom_bias = jnp.array([0.01, 0.02])

dwconv = DepthwiseConv2D(2, 1, (3, 3), 1, 'relu', padding='VALID')
dwconv.init_params(key=jax.random.PRNGKey(1))
# dwconv.params = dict()
# dwconv.params['weights'] = custom_depthwise_kernel
# dwconv.params['biases'] = custom_bias

In [28]:
dwconv.params['weights'].shape

(3, 3, 1, 2)

In [29]:
output, s = dwconv(dwconv.params, sample_input)
output.shape, output

((1, 3, 3, 2),
 Array([[[[0., 0.],
          [0., 0.],
          [0., 0.]],
 
         [[0., 0.],
          [0., 0.],
          [0., 0.]],
 
         [[0., 0.],
          [0., 0.],
          [0., 0.]]]], dtype=float32))

# sublayers

In [30]:
import jax
import numpy as np
from jaxmao.layers import Layer, Dense

class TwoDense(Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers = {
            'dense1': Dense(4, 4, activation=None, use_bias=False),
            'dense2': Dense(4, 5, activation=None, use_bias=False)
        }
    
    def forward(self, params, x, state=None):
        x, state = self.apply(params, x, 'dense1', state)
        x, state = self.apply(params, x, 'dense2', state)
        return x
dw_sep = TwoDense(4, 5)
dw_sep.init_params(key=jax.random.PRNGKey(0))

In [31]:
sample = np.random.normal(-2, 1, (5, 4))
dw_sep.forward(dw_sep.params, sample, dw_sep.state)

Array([[-1.7300946e+00,  1.6384195e+00,  1.8752322e+00, -1.7356913e+00,
        -7.2044253e-01],
       [ 6.1868316e-01, -1.1059952e+00, -5.1548386e-01,  4.7545135e-01,
        -1.3730115e+00],
       [ 1.2308455e+00, -9.6424913e-01, -4.2155027e-01,  1.3255204e+00,
         1.7722866e-01],
       [ 1.6596594e-01,  2.9860577e-02, -1.0167825e+00, -2.3425558e-04,
         3.5464722e-01],
       [-2.8539994e-01,  4.0196428e-01,  7.8584462e-02, -6.5046392e-02,
         1.5615780e+00]], dtype=float32)

In [32]:
print(dw_sep.summary)

layer                output shape         #'s params           #'s states          
dense1               (5, 4)               0                    0                   
dense2               (5, 5)               0                    0                   



# depthwise separable conv2d

In [33]:
import jax
import numpy as np
from jaxmao.layers import Layer, Dense, Conv2D, DepthwiseConv2D
from jaxmao.initializers import *

class DepthwiseSeparableConv2D(Layer):
    def __init__(
            self,
            in_channels,
            out_channels,
            depth_multiplier=1,
            activation='relu',
            padding='SAME',
            strides=(1, 1),
            dilation=(1, 1),
            use_bias=True,
            weights_initializer=HeNormal(),
            bias_initializer=HeNormal(),
            dtype=jnp.float32
                 ):
        super().__init__()
        self.layers = {
            'depthwise' : DepthwiseConv2D(
                                in_channels, depth_multiplier, 
                                kernel_size=(3, 3), strides=strides, 
                                activation=activation, padding=padding,
                                dilation=dilation, use_bias=use_bias,
                                weights_initializer=weights_initializer,
                                bias_initializer=bias_initializer,
                                dtype=dtype
                                          ),
            'pointwise' : Conv2D(
                            in_channels, out_channels, 
                            kernel_size=(1, 1), strides=(1, 1),
                            activation=activation,
                            dilation=dilation, use_bias=use_bias,
                            weights_initializer=weights_initializer,
                            bias_initializer=bias_initializer,
                            dtype=dtype
                                 )
        }
    
    def forward(self, params, x, state=None):
        x, state = self.apply(params, x, 'depthwise', state)
        x, state = self.apply(params, x, 'pointwise', state)
        return x, state

In [34]:
sample = np.random.normal(0, 1, (5, 12, 12, 5))

dwsep = DepthwiseSeparableConv2D(5, 7)
dwsep.init_params(key=jax.random.PRNGKey(4))

In [35]:
x, s = dwsep.forward(dwsep.params, sample, dwsep.state)
x.shape

(5, 12, 12, 7)