In [9]:
from dataclasses import dataclass
import flax.linen as nn
from flax import struct


def mish(x):
    """Mish activation ``x * tanh(softplus(x))``, applied element-wise."""
    gate = jnp.tanh(jax.nn.softplus(x))
    return x * gate

@dataclass(frozen=True)
class AZResnetConfig:
    """Architecture hyper-parameters for ``AZResnet``.

    Frozen so instances are immutable and hashable. A plain dataclass with the
    default ``eq=True`` gets ``__hash__ = None``, which makes it unusable
    wherever a hashable module attribute is required.
    """

    num_blocks: int         # number of residual blocks in the tower
    channels: int           # feature channels of the stem, tower and first policy conv
    policy_channels: int    # channels of the second conv in each policy head
    value_channels: int     # channels of the 1x1 conv in the value head
    num_policy_labels: int  # output width of each policy head's final Dense layer

class ResidualBlock(nn.Module):
    """Residual block: optional squeeze-and-excitation gate, two 3x3 conv/BN/mish
    layers, and an identity skip connection.

    Attributes:
        channels: Feature channels of both convolutions; must equal the input's
            last axis so the skip connection ``x + y`` is shape-compatible.
        se: Whether to gate the input with a squeeze-and-excitation block.
        se_ratio: Channel reduction factor of the SE bottleneck.
    """
    channels: int
    se: bool
    se_ratio: int = 4

    @nn.compact
    def __call__(self, x, train: bool):
        """Apply the block.

        Args:
            x: Feature map, channels-last (the SE squeeze averages over axes
                1 and 2, so presumably NHWC -- TODO confirm with callers).
            train: Passed to every BatchNorm as ``use_running_average=not train``.

        Returns:
            ``x + y``, where ``y`` is the output of the residual branch.
        """
        y = x
        if self.se:
            # Squeeze: global average over the spatial axes.
            squeeze = jnp.mean(y, axis=(1, 2), keepdims=True)

            # Excitation: bottleneck MLP producing one gate per channel.
            excitation = nn.Dense(features=self.channels // self.se_ratio, use_bias=True)(squeeze)
            excitation = nn.relu(excitation)
            excitation = nn.Dense(features=self.channels, use_bias=True)(excitation)
            excitation = nn.hard_sigmoid(excitation)

            # The gated tensor is what the conv stack below must consume.
            y = y * excitation

        y = nn.Conv(features=self.channels, kernel_size=(3, 3), padding=(1, 1), use_bias=False)(y)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = mish(y)
        # Chain onto the first conv's output; reading ``x`` here would make the
        # first conv layer dead computation.
        y = nn.Conv(features=self.channels, kernel_size=(3, 3), padding=(1, 1), use_bias=False)(y)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = mish(y)
        return x + y

class AZResnet(nn.Module):
    """AlphaZero-style residual tower with two policy heads and one value head.

    Attributes:
        config: Architecture hyper-parameters (``AZResnetConfig``).
    """
    config: AZResnetConfig

    @nn.compact
    def __call__(self, x, train: bool):
        """Run the network.

        Args:
            x: Input planes, channels-last -- presumably (batch, height, width,
                planes); TODO confirm with callers.
            train: Passed to every BatchNorm as ``use_running_average=not train``.

        Returns:
            ``(policy, value)``. ``policy`` is a list of two arrays of shape
            ``(batch, num_policy_labels)`` (raw Dense outputs, no softmax).
            ``value`` has shape ``(batch, 1)`` and is squashed by ``tanh``
            into (-1, 1).
        """
        cfg = self.config

        def conv3x3(features):
            # Every 3x3 conv here is "same"-padded and bias-free (BatchNorm follows).
            return nn.Conv(features=features, kernel_size=(3, 3), padding=(1, 1), use_bias=False)

        def batch_norm():
            return nn.BatchNorm(use_running_average=not train)

        # Stem
        x = conv3x3(cfg.channels)(x)
        x = batch_norm()(x)
        x = mish(x)

        # Residual tower
        for _ in range(cfg.num_blocks):
            x = ResidualBlock(channels=cfg.channels, se=True)(x, train=train)

        # Two policy heads: identical architecture, separate parameters.
        # Each head is built completely before the next. Flax numbers
        # auto-named submodules in creation order, so this ordering fixes the
        # parameter names (Conv_1..Conv_4, BatchNorm_1..BatchNorm_4,
        # Dense_0..Dense_1).
        policy = []
        for _ in range(2):
            p = conv3x3(cfg.channels)(x)
            p = batch_norm()(p)
            p = mish(p)
            p = conv3x3(cfg.policy_channels)(p)
            p = batch_norm()(p)
            p = mish(p)
            p = p.reshape((p.shape[0], -1))
            policy.append(nn.Dense(features=cfg.num_policy_labels)(p))

        # Value head
        value = nn.Conv(features=cfg.value_channels, kernel_size=(1, 1), use_bias=False)(x)
        value = batch_norm()(value)
        value = mish(value)
        value = value.reshape((value.shape[0], -1))
        value = nn.Dense(features=256)(value)
        value = mish(value)
        value = nn.Dense(features=1)(value)
        value = nn.tanh(value)

        return policy, value

import jax
import jax.numpy as jnp 
model = AZResnet(AZResnetConfig(
    num_blocks=15,
    channels=256,
    policy_channels=4, 
    value_channels=8,
    num_policy_labels=4992,
))
batch = jnp.ones((1024, 8, 16, 34))
variables = model.init(jax.random.key(0), batch, train=False)
print(variables)
%timeit output = model.apply(variables, jnp.ones((29, 8, 16, 34)), train=True, mutable=['batch_stats'])
#print(output[0].shape)

{'params': {'Conv_0': {'kernel': Array([[[[-8.09571054e-03, -1.59240719e-02, -8.22316762e-03, ...,
          -1.82606392e-02,  4.99152020e-02,  5.23208603e-02],
         [ 9.00918022e-02,  1.30150560e-02,  1.00492738e-01, ...,
           8.98724794e-02,  5.90446629e-02, -7.87667260e-02],
         [ 4.96587437e-03, -3.52317765e-02,  7.62461126e-02, ...,
           1.16681099e-01,  1.24644963e-02,  1.15168676e-01],
         ...,
         [ 4.07651402e-02,  4.74572740e-02,  4.92565706e-02, ...,
          -5.37220351e-02, -7.39137903e-02, -7.60867679e-03],
         [-5.32271340e-02, -4.84290048e-02, -3.24185975e-02, ...,
           9.34603736e-02, -4.42188466e-03, -6.25795918e-03],
         [-3.89340520e-02, -3.60155664e-02,  9.68061015e-02, ...,
           3.59869041e-02,  2.40581967e-02,  8.02789256e-03]],

        [[ 1.61426924e-02, -2.97888536e-02,  5.38754091e-02, ...,
           7.51366317e-02,  4.34052683e-02,  4.14196625e-02],
         [ 7.75327012e-02, -6.54996559e-02,  6.61575049

In [None]:
import torch
from torch.nn import Sequential, Conv2d, BatchNorm2d, ReLU

class _Stem(torch.nn.Module):
    """Input stem proposed by the AlphaZero authors: 3x3 conv -> BatchNorm -> activation."""

    def __init__(self, channels, act_type="relu", nb_input_channels=34):
        """
        Build the stem.
        :param channels: Number of output channels of the 3x3 convolution
        :param act_type: Activation applied after BatchNorm; one of "relu", "lrelu",
                         "mish", "silu"/"swish" (case-insensitive)
        :param nb_input_channels: Number of input channels of the board representation
        :raises ValueError: If act_type is not a supported activation name
        """
        super().__init__()

        self.body = Sequential(
            Conv2d(in_channels=nb_input_channels, out_channels=channels, kernel_size=(3, 3), padding=(1, 1),
                   bias=False),
            BatchNorm2d(num_features=channels),
            self._make_activation(act_type))

    @staticmethod
    def _make_activation(act_type):
        """Return a new activation module for the given act_type name."""
        factories = {
            "relu": lambda: ReLU(inplace=True),
            "lrelu": torch.nn.LeakyReLU,
            "mish": torch.nn.Mish,
            "silu": torch.nn.SiLU,
            "swish": torch.nn.SiLU,  # SiLU is the same function as swish with beta=1
        }
        factory = factories.get(str(act_type).lower())
        if factory is None:
            raise ValueError(
                f"Unsupported act_type {act_type!r}; expected one of {sorted(factories)}")
        return factory()

    def forward(self, x):
        """
        Compute the forward pass.
        :param x: Input tensor of shape (batch, nb_input_channels, height, width)
        :return: Activation maps of shape (batch, channels, height, width); the padded
                 3x3 convolution preserves height and width
        """
        return self.body(x)

# channels is required; 256 matches the JAX tower width, and the default
# nb_input_channels=34 matches the 34-plane input used in the JAX cell.
model = _Stem(channels=256)