In [1]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jaxmao.layers import Conv2D, SimpleDense, Dense, BatchNorm, ReLU, Flatten, StableSoftmax, BatchNorm2D, DepthwiseConv2D, Activation
from jaxmao.modules import Module
from jaxmao.optimizers import GradientDescent
from jaxmao.losses import CategoricalCrossEntropy
from jaxmao.metrics import Accuracy, Precision, Recall

In [2]:
jax.devices()

I0000 00:00:1697866458.138192   57239 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


[CpuDevice(id=0)]

# Dense

In [3]:
# Define the input shape and number of output neurons
input_shape = (4,)  # Input shape (number of input features)
output_neurons = 3  # Number of output neurons

# Initialize custom weights and bias
custom_weights = np.array([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2]
])
custom_bias = np.array([0.01, 0.02, 0.03])

In [4]:
dense = SimpleDense(4, 3, use_bias=False)
dense.init_params(key=jax.random.PRNGKey(0))
# dense.params = dict()
# dense.params['weights'] = custom_weights
# dense.params['biases'] = custom_bias

bn = BatchNorm(3, momentum=0.5, axis_mean=0)
bn.init_params(None)
bn.set_inference_mode()


In [5]:
sample_input = np.arange(8*4).reshape(8, 4)  # Shape will be (1, 4)
output, _ = dense(dense.params, sample_input)
output.shape

(8, 3)

## Dense, bn

In [6]:
bn.eps = np.float32(1e-3) # this is what Keras use. I use 1e-5

output2, _ = bn(bn.params, output)
output2

Array([[0.        , 0.        , 0.70880806],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ]], dtype=float32)

In [7]:
bn.params, bn.state

({'gamma': Array([1., 1., 1.], dtype=float32),
  'beta': Array([0., 0., 0.], dtype=float32)},
 {'running_mean': Array([0., 0., 0.], dtype=float32),
  'running_var': Array([1., 1., 1.], dtype=float32),
  'momentum': 0.5,
  'training': False})

In [8]:
output3 = ReLU()(np.array([[-5, 0, 5]]))
output3

Array([[0, 0, 5]], dtype=int32)

# Conv2D

In [9]:
sample_input = np.arange(25).reshape(1, 5, 5, 1).astype('float32')

#### valid padding

In [10]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='VALID')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [11]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 3, 3, 1),
 Array([36.609997, 41.109997, 45.609997, 59.109997, 63.609997, 68.11    ,
        81.61001 , 86.10999 , 90.61    ], dtype=float32))

#### same padding

In [12]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='SAME')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [13]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 5, 5, 1),
 Array([10.01    , 16.31    , 20.21    , 24.109999, 16.01    , 24.31    ,
        36.609997, 41.109997, 45.609997, 29.11    , 40.81    , 59.109997,
        63.609997, 68.11    , 42.609997, 57.31    , 81.61001 , 86.10999 ,
        90.61    , 56.110004, 30.410002, 41.51    , 43.609997, 45.71    ,
        26.81    ], dtype=float32))

# Conv2D, bn

In [14]:
sample_input = np.arange(25).reshape(1, 5, 5, 1).astype('float32')

#### valid padding

In [15]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='VALID')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [16]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 3, 3, 1),
 Array([36.609997, 41.109997, 45.609997, 59.109997, 63.609997, 68.11    ,
        81.61001 , 86.10999 , 90.61    ], dtype=float32))

In [17]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn.eps = np.float32(1e-3)

In [18]:
bn_out, state = convbn(convbn.params, conv_output)
bn_out.shape, bn_out.ravel()

((1, 3, 3, 1),
 Array([36.591705, 41.08946 , 45.58721 , 59.080467, 63.578217, 68.07597 ,
        81.56924 , 86.06697 , 90.564735], dtype=float32))

#### same padding

In [19]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='SAME')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [20]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 5, 5, 1),
 Array([10.01    , 16.31    , 20.21    , 24.109999, 16.01    , 24.31    ,
        36.609997, 41.109997, 45.609997, 29.11    , 40.81    , 59.109997,
        63.609997, 68.11    , 42.609997, 57.31    , 81.61001 , 86.10999 ,
        90.61    , 56.110004, 30.410002, 41.51    , 43.609997, 45.71    ,
        26.81    ], dtype=float32))

In [21]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn.eps = np.float32(1e-3)

In [22]:
bn_out, state = convbn(convbn.params, conv_output)
bn_out.shape, bn_out.ravel()

((1, 5, 5, 1),
 Array([10.004999, 16.301851, 20.199902, 24.097954, 16.002   , 24.297853,
        36.591705, 41.08946 , 45.58721 , 29.095457, 40.78961 , 59.080467,
        63.578217, 68.07597 , 42.58871 , 57.28137 , 81.56924 , 86.06697 ,
        90.564735, 56.081974, 30.39481 , 41.48926 , 43.58821 , 45.687164,
        26.796606], dtype=float32))

In [23]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn2d = BatchNorm2D(1)
convbn2d.init_params()
convbn2d.set_inference_mode()

In [24]:
convbn(convbn.params, bn_out)[0].ravel() == convbn2d(convbn2d.params, bn_out)[0].ravel()

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True], dtype=bool)

# Depthwise conv2d

In [25]:
sample_input = np.arange(50).reshape(1, 5, 5, 2).astype('float32')

custom_depthwise_kernel = jnp.array([[[[0.1, 0.2]], [[0.3, 0.4]], [[0.5, 0.6]]],
                                    [[[0.7, 0.8]], [[0.9, 1.0]], [[1.1, 1.2]]],
                                    [[[1.3, 1.4]], [[1.5, 1.6]], [[1.7, 1.8]]]]).reshape(3, 3, 1, 2)

custom_bias = jnp.array([0.01, 0.02])

dwconv = DepthwiseConv2D(2, 1, (3, 3), 1, 'relu', padding='VALID')
dwconv.init_params(key=jax.random.PRNGKey(1))
# dwconv.params = dict()
# dwconv.params['weights'] = custom_depthwise_kernel
# dwconv.params['biases'] = custom_bias

In [26]:
dwconv.params['weights'].shape

(3, 3, 1, 2)

In [27]:
output, s = dwconv(dwconv.params, sample_input)
output.shape, output

((1, 3, 3, 2),
 Array([[[[0., 0.],
          [0., 0.],
          [0., 0.]],
 
         [[0., 0.],
          [0., 0.],
          [0., 0.]],
 
         [[0., 0.],
          [0., 0.],
          [0., 0.]]]], dtype=float32))

# sublayers

In [28]:
import jax
import numpy as np
from jaxmao.layers import Layer, Dense

class TwoDense(Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers = {
            'dense1': Dense(4, 4, activation=None, use_bias=False),
            'dense2': Dense(4, 5, activation=None, use_bias=False)
        }
    
    def forward(self, params, x, state=None):
        x, state = self.apply(params, x, 'dense1', state)
        x, state = self.apply(params, x, 'dense2', state)
        return x
dw_sep = TwoDense(4, 5)
dw_sep.init_params(key=jax.random.PRNGKey(0))
dw_sep.set_inference_mode()
dw_sep.set_training_mode()

In [29]:
sample = np.random.normal(-2, 1, (5, 4))
dw_sep.forward(dw_sep.params, sample, dw_sep.state)

Array([[ 0.18857716,  0.7015928 , -0.9163361 , -0.05308571,  0.2983981 ],
       [ 0.0421089 ,  0.0218102 ,  0.16859055,  0.03580065, -0.15294687],
       [-0.9547683 , -1.0994829 ,  0.1783854 , -0.6866428 , -0.58513   ],
       [ 1.7384334 ,  1.458357  ,  1.6878839 ,  1.8142594 ,  1.7112275 ],
       [-1.0143512 , -1.082277  , -1.1185237 , -1.1103315 , -1.2715489 ]],      dtype=float32)

In [30]:
print(dw_sep.summary)

layer                output shape         #'s params           #'s states          
dense1               (5, 4)               28                   0                   
dense2               (5, 5)               35                   0                   



In [31]:
dw_sep.layers['dense1'].shapes

{}

In [32]:
dw_sep.layers['dense1'].layers['dense/batch_norm'].shapes

{'gamma': (4,), 'beta': (4,)}

In [33]:
dw_sep.layers['dense1'].layers['dense/simple_dense'].shapes

{'weights': (4, 4), 'biases': (4,)}

In [34]:
dw_sep.layers['dense1'].layers['dense/simple_dense'].num_params

20

In [35]:
dw_sep.layers['dense1'].shapes

{}

# depthwise separable conv2d

In [36]:
import jax
import numpy as np
from jaxmao.layers import Layer, Dense, Conv2D, DepthwiseConv2D
from jaxmao.initializers import *

class DepthwiseSeparableConv2D(Layer):
    def __init__(
            self,
            in_channels,
            out_channels,
            depth_multiplier=1,
            activation='relu',
            padding='SAME',
            strides=(1, 1),
            dilation=(1, 1),
            use_bias=True,
            weights_initializer=HeNormal(),
            bias_initializer=HeNormal(),
            dtype=jnp.float32
                 ):
        super().__init__()
        self.layers = {
            'depthwise' : DepthwiseConv2D(
                                in_channels, depth_multiplier, 
                                kernel_size=(3, 3), strides=strides, 
                                activation=activation, padding=padding,
                                dilation=dilation, use_bias=use_bias,
                                weights_initializer=weights_initializer,
                                bias_initializer=bias_initializer,
                                dtype=dtype
                                          ),
            'pointwise' : Conv2D(
                            in_channels, out_channels, 
                            kernel_size=(1, 1), strides=(1, 1),
                            activation=activation,
                            dilation=dilation, use_bias=use_bias,
                            weights_initializer=weights_initializer,
                            bias_initializer=bias_initializer,
                            dtype=dtype
                                 )
        }
    
    def forward(self, params, x, state=None):
        x, state = self.apply(params, x, 'depthwise', state)
        x, state = self.apply(params, x, 'pointwise', state)
        return x, state

In [37]:
sample = np.random.normal(0, 1, (5, 12, 12, 5))

dwsep = DepthwiseSeparableConv2D(5, 7)
dwsep.init_params(key=jax.random.PRNGKey(4))

In [38]:
x, s = dwsep.forward(dwsep.params, sample, dwsep.state)
x.shape

(5, 12, 12, 7)

# dropout

In [39]:
from jaxmao.layers import Dropout

dropout = Dropout(jax.random.PRNGKey(0), 0.5)
dropout.forward = jax.vmap(dropout.forward, in_axes=(None, 0, None))

dropout.set_inference_mode()
dropout.set_training_mode()

In [40]:
sample_input = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]])
out, s = dropout(dropout.params, sample_input)

In [41]:
sample_input

array([[0.1, 0.2, 0.3, 0.4, 0.5]])

In [42]:
out

Array([[0.2, 0.4, 0. , 0. , 1. ]], dtype=float32)

# pooling

In [56]:
from jax import vmap

sample_input = np.array([[[
    [1], [2], [3], [4],
    [5], [6], [7], [8],
    [9], [10], [11], [12],
    [13], [14], [15], [16]
]]]).reshape(4, 4, 1).astype('float32')

sample_input = np.arange(5*4*4*1).astype('float32').reshape(5, 4, 4, 1)

from jaxmao.layers import MaxPooling2D
max_pool = MaxPooling2D()
max_pool.forward = vmap(max_pool.forward)
out, s = max_pool(max_pool.params, sample_input)
out

Array([[[[ 5.],
         [ 7.]],

        [[13.],
         [15.]]],


       [[[21.],
         [23.]],

        [[29.],
         [31.]]],


       [[[37.],
         [39.]],

        [[45.],
         [47.]]],


       [[[53.],
         [55.]],

        [[61.],
         [63.]]],


       [[[69.],
         [71.]],

        [[77.],
         [79.]]]], dtype=float32)

In [57]:
from jax import lax
lax.

AttributeError: module 'jax.lax' has no attribute 'mean'

In [44]:
(2, 2) + (1, )

(2, 2, 1)

In [45]:
max_pool.kernel_size
max_pool.strides

(2, 2, 1)

In [46]:
# from jax import lax

# padding_config = [(0, 0)] * len(sample_input.shape)  # No padding

# z = lax.reduce_window(
#     sample_input,
#     init_value=jnp.finfo(jnp.float32).min,
#     computation=lax.max,
#     window_dimensions=(2, 2, 1),
#     window_strides=(2, 2, 1),
#     padding=padding_config
# )
# z

In [47]:
sample_input.shape

(4, 4, 1)

In [48]:
max_pool.reducing_fn

<function jax._src.lax.lax.max(x: 'ArrayLike', y: 'ArrayLike') -> 'Array'>