In [6]:
from dataclasses import dataclass
import flax.linen as nn
from flax import struct


def mish(x):
    """Apply the Mish activation, ``x * tanh(softplus(x))``, element-wise.

    Args:
        x: array-like input.

    Returns:
        An array with the same shape as ``x``.

    NOTE: relies on the module-level ``jax`` / ``jnp`` names, which this cell
    imports further down; they are looked up when the function is called.
    """
    gate = jnp.tanh(jax.nn.softplus(x))
    return x * gate

@dataclass(frozen=True)
class AZResnetConfig:
    """Hyper-parameters for ``AZResnet``.

    The config is frozen, which makes it immutable and hashable, because it is
    stored as a field of a Flax linen ``nn.Module``. Linen modules are
    themselves immutable dataclasses that hash their fields. A plain mutable
    dataclass has ``__hash__ = None``, so ``hash(model)`` would raise
    ``TypeError``, e.g. when the module is used as a static ``jax.jit``
    argument.

    Attributes:
        model_type: architecture tag (e.g. ``"resnet"``); not read by
            ``AZResnet`` in this cell.
        policy_head_out_size: number of outputs of the policy head's final
            ``Dense`` layer.
        num_blocks: number of residual blocks in the tower.
        num_channels: feature width of the stem and of every residual block.
    """

    model_type: str
    policy_head_out_size: int
    num_blocks: int
    num_channels: int

class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm stages with an identity skip.

    Layout: conv3x3 -> BN -> mish -> conv3x3 -> BN, then ``mish(input + out)``.
    Convolutions use 'SAME' padding, stride 1 and no bias. The skip adds the
    input unchanged, so the input should already have ``channels`` features.

    Attributes:
        channels: number of output features of both convolutions.
    """

    channels: int

    @nn.compact
    def __call__(self, x, train: bool):
        """Run the block.

        Args:
            x: channels-last feature map.
            train: if True BatchNorm uses batch statistics; if False it uses
                the stored running averages.

        Returns:
            The activated sum of the skip connection and the conv branch.
        """
        shortcut = x
        # Submodules are created in the same order as two hand-written stages
        # (Conv_0, BatchNorm_0, Conv_1, BatchNorm_1), so parameter names match.
        for stage in range(2):
            x = nn.Conv(
                features=self.channels,
                kernel_size=(3, 3),
                strides=(1, 1),
                padding='SAME',
                use_bias=False,
            )(x)
            x = nn.BatchNorm(use_running_average=not train)(x)
            if stage == 0:
                x = mish(x)  # only the first stage is activated before the sum
        return mish(shortcut + x)

class AZResnet(nn.Module):
    """Residual tower with a policy head and a value head.

    Stem: 1x1 conv -> BN -> ReLU to ``num_channels`` features, followed by
    ``num_blocks`` ``ResidualBlock``s. Each head is a 1x1 conv -> BN -> ReLU,
    flattened per leading index and projected with a ``Dense`` layer.

    Attributes:
        config: architecture hyper-parameters.
    """

    config: AZResnetConfig

    @nn.compact
    def __call__(self, x, train: bool):
        """Compute policy and value outputs.

        Args:
            x: channels-last input. NOTE(review): expected (batch, height,
                width, channels); flax ``Conv`` treats a rank-3 array as a
                single unbatched image, so its first axis would be convolved
                over -- confirm callers include a channel axis.
            train: if True BatchNorm layers use batch statistics; if False
                they use the stored running averages.

        Returns:
            ``(policy, value)``: ``policy`` has shape
            ``(x.shape[0], policy_head_out_size)`` with no softmax applied;
            ``value`` has shape ``(x.shape[0], 1)`` squashed by ``tanh``.
        """
        cfg = self.config
        frozen_stats = not train

        def conv_bn_relu(h, features):
            # Shared 1x1 conv -> BN -> ReLU used by the stem and both heads.
            # A plain closure (not a Module method) keeps `self` as the parent,
            # so auto-generated names follow creation order exactly.
            h = nn.Conv(
                features=features,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding='SAME',
                use_bias=False,
            )(h)
            h = nn.BatchNorm(use_running_average=frozen_stats)(h)
            return nn.relu(h)

        trunk = conv_bn_relu(x, cfg.num_channels)
        for _ in range(cfg.num_blocks):
            trunk = ResidualBlock(channels=cfg.num_channels)(trunk, train=train)

        # Policy head: 2 feature planes, flattened, then projected to logits.
        policy_planes = conv_bn_relu(trunk, 2)
        policy_flat = policy_planes.reshape((policy_planes.shape[0], -1))
        policy = nn.Dense(features=cfg.policy_head_out_size)(policy_flat)

        # Value head: 1 feature plane, flattened, projected to a scalar in [-1, 1].
        value_planes = conv_bn_relu(trunk, 1)
        value_flat = value_planes.reshape((value_planes.shape[0], -1))
        value = nn.tanh(nn.Dense(features=1)(value_flat))

        return policy, value

import jax
import jax.numpy as jnp 
model = AZResnet(AZResnetConfig(
    model_type="resnet",
    policy_head_out_size=4992,
    num_blocks=15,
    num_channels=4,
))
batch = jnp.ones((1024, 8, 16))
variables = model.init(jax.random.key(0), batch, train=False)
%timeit output = model.apply(variables, batch, train=False)

136 ms ± 2.07 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
