In [1]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

In [10]:
import jax
import jax.numpy as jnp
import numpy as np
from jaxmao.layers import Conv2D, SimpleDense, Dense, BatchNorm, ReLU, Flatten, StableSoftmax, BatchNorm2D, DepthwiseConv2D, Activation
from jaxmao.modules import Module
from jaxmao.optimizers import GradientDescent
from jaxmao.losses import CategoricalCrossEntropy
from jaxmao.metrics import Accuracy, Precision, Recall

print('jax.devices() :', jax.devices())

import tensorflow as tf
print('tf.config.list_physical_devices(): ', tf.config.list_physical_devices())

jax.devices() : [CpuDevice(id=0)]
tf.config.list_physical_devices():  [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


# Dense

In [3]:
# Define the input shape and number of output neurons
input_shape = (4,)  # Input shape (number of input features)
output_neurons = 3  # Number of output neurons

# Initialize custom weights and bias
custom_weights = np.array([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2]
])
custom_bias = np.array([0.01, 0.02, 0.03])

In [4]:
dense = SimpleDense(4, 3, use_bias=False)
dense.init_params(key=jax.random.PRNGKey(0))
# dense.params = dict()
# dense.params['weights'] = custom_weights
# dense.params['biases'] = custom_bias

bn = BatchNorm(3, momentum=0.5, axis_mean=0)
bn.init_params(None)
bn.set_inference_mode()


In [5]:
sample_input = np.arange(8*4).reshape(8, 4)  # Shape will be (1, 4)
output, _ = dense(dense.params, sample_input)
output.shape

(8, 3)

## Dense, bn

In [6]:
bn.eps = np.float32(1e-3) # this is what Keras use. I use 1e-5

output2, _ = bn(bn.params, output)
output2

Array([[0.        , 0.        , 0.70880806],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ]], dtype=float32)

In [7]:
bn.params, bn.state

({'gamma': Array([1., 1., 1.], dtype=float32),
  'beta': Array([0., 0., 0.], dtype=float32)},
 {'running_mean': Array([0., 0., 0.], dtype=float32),
  'running_var': Array([1., 1., 1.], dtype=float32),
  'momentum': 0.5,
  'training': False})

In [8]:
output3 = ReLU()(np.array([[-5, 0, 5]]))
output3

Array([[0, 0, 5]], dtype=int32)

# Conv2D

In [38]:
sample_input = np.arange(25).reshape(1, 5, 5, 1).astype('float32')

#### valid padding

In [42]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='VALID', batch_norm=False)
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [43]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 3, 3, 1),
 Array([36.609997, 41.109997, 45.609997, 59.109997, 63.609997, 68.11    ,
        81.61001 , 86.10999 , 90.61    ], dtype=float32))

#### same padding

In [13]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='SAME')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [14]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 5, 5, 1),
 Array([10.01    , 16.31    , 20.21    , 24.109999, 16.01    , 24.31    ,
        36.609997, 41.109997, 45.609997, 29.11    , 40.81    , 59.109997,
        63.609997, 68.11    , 42.609997, 57.31    , 81.61001 , 86.10999 ,
        90.61    , 56.110004, 30.410002, 41.51    , 43.609997, 45.71    ,
        26.81    ], dtype=float32))

# Conv2D, bn

In [15]:
sample_input = np.arange(25).reshape(1, 5, 5, 1).astype('float32')

#### valid padding

In [16]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='VALID')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [17]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 3, 3, 1),
 Array([36.609997, 41.109997, 45.609997, 59.109997, 63.609997, 68.11    ,
        81.61001 , 86.10999 , 90.61    ], dtype=float32))

In [18]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn.eps = np.float32(1e-3)

In [19]:
bn_out, state = convbn(convbn.params, conv_output)
bn_out.shape, bn_out.ravel()

((1, 3, 3, 1),
 Array([36.591705, 41.08946 , 45.58721 , 59.080467, 63.578217, 68.07597 ,
        81.56924 , 86.06697 , 90.564735], dtype=float32))

#### same padding

In [20]:
custom_kernel = np.array([[[[0.1]], [[0.2]], [[0.3]]],
                          [[[0.4]], [[0.5]], [[0.6]]],
                          [[[0.7]], [[0.8]], [[0.9]]]])

custom_bias = np.array([0.01])

conv = Conv2D(1, 1, (3,3), 1, padding='SAME')
conv.params = dict(weights=custom_kernel, biases=custom_bias)
# conv.params

In [21]:
conv_output, state = conv(conv.params, sample_input)
conv_output.shape, conv_output.ravel()

((1, 5, 5, 1),
 Array([10.01    , 16.31    , 20.21    , 24.109999, 16.01    , 24.31    ,
        36.609997, 41.109997, 45.609997, 29.11    , 40.81    , 59.109997,
        63.609997, 68.11    , 42.609997, 57.31    , 81.61001 , 86.10999 ,
        90.61    , 56.110004, 30.410002, 41.51    , 43.609997, 45.71    ,
        26.81    ], dtype=float32))

In [22]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn.eps = np.float32(1e-3)

In [23]:
bn_out, state = convbn(convbn.params, conv_output)
bn_out.shape, bn_out.ravel()

((1, 5, 5, 1),
 Array([10.004999, 16.301851, 20.199902, 24.097954, 16.002   , 24.297853,
        36.591705, 41.08946 , 45.58721 , 29.095457, 40.78961 , 59.080467,
        63.578217, 68.07597 , 42.58871 , 57.28137 , 81.56924 , 86.06697 ,
        90.564735, 56.081974, 30.39481 , 41.48926 , 43.58821 , 45.687164,
        26.796606], dtype=float32))

In [24]:
convbn = BatchNorm(1, axis_mean=(0, 1, 2))
convbn.init_params()
convbn.set_inference_mode()

convbn2d = BatchNorm2D(1)
convbn2d.init_params()
convbn2d.set_inference_mode()

In [25]:
convbn(convbn.params, bn_out)[0].ravel() == convbn2d(convbn2d.params, bn_out)[0].ravel()

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True], dtype=bool)

# Depthwise conv2d

In [26]:
sample_input = np.arange(50).reshape(1, 5, 5, 2).astype('float32')

custom_depthwise_kernel = jnp.array([[[[0.1, 0.2]], [[0.3, 0.4]], [[0.5, 0.6]]],
                                    [[[0.7, 0.8]], [[0.9, 1.0]], [[1.1, 1.2]]],
                                    [[[1.3, 1.4]], [[1.5, 1.6]], [[1.7, 1.8]]]]).reshape(3, 3, 1, 2)

custom_bias = jnp.array([0.01, 0.02])

dwconv = DepthwiseConv2D(2, 1, (3, 3), 1, 'relu', padding='VALID')
dwconv.init_params(key=jax.random.PRNGKey(1))
dwconv.params = dict()
dwconv.params['weights'] = custom_depthwise_kernel
dwconv.params['biases'] = custom_bias

In [27]:
dwconv.params['weights'].shape

(3, 3, 1, 2)

In [28]:
output, s = dwconv(dwconv.params, sample_input)
output.shape, output

((1, 3, 3, 2),
 Array([[[[135.61   , 155.42   ],
          [151.81   , 173.42   ],
          [168.01   , 191.42   ]],
 
         [[216.61   , 245.42   ],
          [232.81   , 263.41998],
          [249.01   , 281.42   ]],
 
         [[297.61   , 335.41998],
          [313.81003, 353.41998],
          [330.01   , 371.42   ]]]], dtype=float32))

# sublayers

In [29]:
import jax
import numpy as np
from jaxmao.layers import Layer, Dense

class TwoDense(Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers = {
            'dense1': Dense(4, 4, activation=None, use_bias=False),
            'dense2': Dense(4, 5, activation=None, use_bias=False)
        }
    
    def forward(self, params, x, state=None):
        x, state = self.apply(params, x, 'dense1', state)
        x, state = self.apply(params, x, 'dense2', state)
        return x
dw_sep = TwoDense(4, 5)
dw_sep.init_params(key=jax.random.PRNGKey(0))
dw_sep.set_inference_mode()
dw_sep.set_training_mode()

In [30]:
sample = np.random.normal(-2, 1, (5, 4))
dw_sep.forward(dw_sep.params, sample, dw_sep.state)

Array([[-5.005852  ,  2.9989111 ,  6.408935  , -2.4594626 , -0.9159826 ],
       [-0.3220476 ,  4.025765  ,  5.8815145 , -0.08278608,  0.47821122],
       [-4.0098877 ,  0.2605685 ,  4.0933294 , -2.0196974 , -2.928052  ],
       [-4.5613813 ,  3.0746803 ,  6.318892  , -2.298581  , -1.1305283 ],
       [-4.8671975 ,  5.292586  ,  6.9901257 , -2.7495313 , -0.35684916]],      dtype=float32)

In [31]:
print(dw_sep.summary)

layer                output shape         #'s params           #'s states          
dense1               (5, 4)               20                   0                   
dense2               (5, 5)               25                   0                   



In [32]:
dw_sep.layers['dense1'].shapes

{}

In [33]:
dw_sep.layers['dense1'].layers['dense/batch_norm'].shapes

KeyError: 'dense/batch_norm'

In [None]:
dw_sep.layers['dense1'].layers['dense/simple_dense'].shapes

{'weights': (4, 4), 'biases': (4,)}

In [None]:
dw_sep.layers['dense1'].layers['dense/simple_dense'].num_params

20

In [None]:
dw_sep.layers['dense1'].shapes

{}

# depthwise separable conv2d

In [None]:
import jax
import numpy as np
from jaxmao.layers import Layer, Dense, Conv2D, DepthwiseConv2D
from jaxmao.initializers import *

class DepthwiseSeparableConv2D(Layer):
    def __init__(
            self,
            in_channels,
            out_channels,
            depth_multiplier=1,
            activation='relu',
            padding='SAME',
            strides=(1, 1),
            dilation=(1, 1),
            use_bias=True,
            weights_initializer=HeNormal(),
            bias_initializer=HeNormal(),
            dtype=jnp.float32
                 ):
        super().__init__()
        self.layers = {
            'depthwise' : DepthwiseConv2D(
                                in_channels, depth_multiplier, 
                                kernel_size=(3, 3), strides=strides, 
                                activation=activation, padding=padding,
                                dilation=dilation, use_bias=use_bias,
                                weights_initializer=weights_initializer,
                                bias_initializer=bias_initializer,
                                dtype=dtype
                                          ),
            'pointwise' : Conv2D(
                            in_channels, out_channels, 
                            kernel_size=(1, 1), strides=(1, 1),
                            activation=activation,
                            dilation=dilation, use_bias=use_bias,
                            weights_initializer=weights_initializer,
                            bias_initializer=bias_initializer,
                            dtype=dtype
                                 )
        }
    
    def forward(self, params, x, state=None):
        x, state = self.apply(params, x, 'depthwise', state)
        x, state = self.apply(params, x, 'pointwise', state)
        return x, state

In [None]:
sample = np.random.normal(0, 1, (5, 12, 12, 5))

dwsep = DepthwiseSeparableConv2D(5, 7)
dwsep.init_params(key=jax.random.PRNGKey(4))

In [None]:
x, s = dwsep.forward(dwsep.params, sample, dwsep.state)
x.shape

(5, 12, 12, 7)

# dropout

In [None]:
from jaxmao.layers import Dropout

dropout = Dropout(jax.random.PRNGKey(0), 0.5)
dropout.forward = jax.vmap(dropout.forward, in_axes=(None, 0, None))

dropout.set_inference_mode()
dropout.set_training_mode()

In [None]:
sample_input = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]])
out, s = dropout(dropout.params, sample_input)

In [None]:
sample_input

array([[0.1, 0.2, 0.3, 0.4, 0.5]])

In [None]:
out

Array([[0.2, 0.4, 0. , 0. , 1. ]], dtype=float32)

# pooling

In [None]:
from jax import vmap
import numpy as np

sample_input = np.array([[[
    [1], [2], [3], [4],
    [5], [6], [7], [8],
    [9], [10], [11], [12],
    [13], [14], [15], [16]
]]]).reshape(1, 4, 4, 1).astype('float32')

sample_input = np.arange(5*4*4*1).astype('float32').reshape(5, 4, 4, 1)

from jaxmao.layers import MaxPooling2D
max_pool = MaxPooling2D()
max_pool.reducing_fn = reduce_max
out, s = max_pool(max_pool.params, sample_input)
out

NameError: name 'reduce_max' is not defined

In [None]:
import numpy as np
sample_input = np.array([[[
    [1], [2], [3], [4],
    [5], [6], [7], [8],
    [9], [10], [11], [12],
    [13], [14], [15], [16]
]]]).reshape(1, 4, 4, 1).astype('float32')

# sample_input = np.arange(5*4*4*1).astype('float32').reshape(5, 4, 4, 1)

from jax import vmap, lax
from jaxmao.layers import AveragePooling2D, MaxPooling2D
import jax.numpy as jnp

def rmean(a, b):
    return lax.add(a, b)

avg_pool = AveragePooling2D()
# avg_pool.reducing_fn = lax.add
# avg_pool.forward = vmap(avg_pool._pool_forward, in_axes=(None, 0, None))
out, s = avg_pool(avg_pool.params, sample_input)
out.shape, out

((1, 2, 2, 1),
 Array([[[[ 3.5],
          [ 5.5]],
 
         [[11.5],
          [13.5]]]], dtype=float32))

# average pooling

In [None]:
from jax import vmap
import numpy as np
sample_input = np.array([[
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]],
    [[25, 26, 27], [28, 29, 30], [31, 32, 33], [34, 35, 36]],
    [[37, 38, 39], [40, 41, 42], [43, 44, 45], [46, 47, 48]]
]])

from jaxmao.layers import GlobalAveragePooling2D
gap = GlobalAveragePooling2D()
out, s = gap(gap.params, sample_input)
out.shape, out

((1, 3), Array([[23.5, 24.5, 25.5]], dtype=float32))

In [None]:
from jax import vmap
import numpy as np
sample_input = np.array([[
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]],
    [[25, 26, 27], [28, 29, 30], [31, 32, 33], [34, 35, 36]],
    [[37, 38, 39], [40, 41, 42], [43, 44, 45], [46, 47, 48]]
]])

from jaxmao.layers import GlobalMaxPooling2D
gmp = GlobalMaxPooling2D()
# gmp.forward = vmap(gmp.forward, in_axes=(None, 0, None))
out, s = gmp(gmp.params, sample_input)
out.shape, out

((1, 3), Array([[46, 47, 48]], dtype=int32))