In [None]:
import os
# Ask XLA to expose 4 virtual host (CPU) devices so the pmap cells further down
# can be exercised on a single machine.
# NOTE: presumably this must run before JAX initialises its backend to take effect.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
# Load the MNIST train and test splits from TensorFlow Datasets.
# Samples are dicts that carry at least the 'image' and 'label' keys used below.
train_ds = tfds.load('mnist', split='train')
test_ds = tfds.load('mnist', split='test')

def data_normalize(ds):
    """Return `ds` with every image cast to float32 and divided by 255.

    Each sample is mapped to a dict holding only the rescaled 'image' and the
    untouched 'label'.
    """
    def _rescale(sample):
        pixels = tf.cast(sample['image'], tf.float32)
        return {'image': pixels / 255., 'label': sample['label']}

    return ds.map(_rescale)

# Pipeline per split: rescale -> shuffle -> batches of 100 -> prefetch -> keep at most 1000 batches.
# NOTE(review): buffer_size=10 only shuffles within a 10-sample window, and the test
# split is shuffled as well -- confirm both are intended.
train_ds = data_normalize(train_ds).shuffle(buffer_size=10, seed=42).batch(100).prefetch(1).take(1000)
test_ds = data_normalize(test_ds).shuffle(buffer_size=10, seed=42).batch(100).prefetch(1).take(1000)

# Number of batches in each split as reported by tf.data.
total_batch = train_ds.cardinality().numpy()
total_tbatch = test_ds.cardinality().numpy()

In [None]:
from model.resnet_v2 import *

# Grab a single batch to serve as the example input for parameter initialisation.
for batch in train_ds.as_numpy_iterator():
    x = batch['image']
    y = batch['label']
    break  # one batch is enough; without this the loop walks the entire dataset

resnet = ResNet(num_classes=10, act_fn=nn.relu, block_class=ResNetBlock)
variables = initialize(resnet, 42, x)

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn

resnet_kernel_init = jax.nn.initializers.variance_scaling(2.0, mode='fan_out', distribution='normal')


class ResNetBlock(nn.Module):
    """Residual block: conv -> BN -> act -> conv -> BN, added to the skip path, then act.

    When `subsample` is set, the first conv uses stride 2 and the skip path is
    projected with a strided 1x1 conv so both branches can be added.
    """
    act_fn : callable  # Activation function
    c_out : int   # Output feature size
    subsample : bool = False  # If True, we apply a stride inside F

    @nn.compact
    def __call__(self, x, train=True):
        first_stride = (2, 2) if self.subsample else (1, 1)
        frozen_stats = not train  # BatchNorm uses running averages outside training

        # Residual branch F (submodules are created in the same order as before,
        # so the automatically assigned parameter names do not change).
        branch = nn.Conv(self.c_out,
                         kernel_size=(3, 3),
                         strides=first_stride,
                         kernel_init=resnet_kernel_init,
                         use_bias=False)(x)
        branch = nn.BatchNorm()(branch, use_running_average=frozen_stats)
        branch = self.act_fn(branch)
        branch = nn.Conv(self.c_out,
                         kernel_size=(3, 3),
                         kernel_init=resnet_kernel_init,
                         use_bias=False)(branch)
        branch = nn.BatchNorm()(branch, use_running_average=frozen_stats)

        # Skip path: identity, or a strided 1x1 projection when subsampling.
        shortcut = x
        if self.subsample:
            shortcut = nn.Conv(self.c_out,
                               kernel_size=(1, 1),
                               strides=(2, 2),
                               kernel_init=resnet_kernel_init,
                               use_bias=False)(x)

        return self.act_fn(branch + shortcut)

class ResNet(nn.Module):
    """Stack of residual blocks followed by global average pooling and a linear classifier."""
    num_classes : int
    act_fn : callable
    block_class : nn.Module
    num_blocks : tuple = (3, 3, 3)
    c_hidden : tuple = (16, 32, 64)

    @nn.compact
    def __call__(self, x, train=True):
        # Stem: a first convolution that lifts the image to c_hidden[0] channels.
        feats = nn.Conv(self.c_hidden[0],
                        kernel_size=(3, 3),
                        kernel_init=resnet_kernel_init,
                        use_bias=False)(x)
        # BN + activation are applied here only for ResNetBlock; any other block
        # class (e.g. a pre-activation block) receives the raw conv output.
        if self.block_class == ResNetBlock:
            feats = nn.BatchNorm()(feats, use_running_average=not train)
            feats = self.act_fn(feats)

        # Residual stages: the first block of every stage except stage 0 subsamples.
        for stage, depth in enumerate(self.num_blocks):
            for position in range(depth):
                downsample = stage > 0 and position == 0
                block = self.block_class(c_out=self.c_hidden[stage],
                                         act_fn=self.act_fn,
                                         subsample=downsample)
                feats = block(feats, train=train)

        # Classification head: average over axes 1 and 2, flatten, project to classes.
        pooled = feats.mean(axis=(1, 2))
        pooled = pooled.reshape((pooled.shape[0], -1))
        return nn.Dense(self.num_classes, use_bias=False)(pooled)

In [None]:
# Take the first batch only; it is used as the example input below.
for batch in train_ds.as_numpy_iterator():
    x = batch['image']
    y = batch['label']
    break

# Build the model and print a per-layer summary table for the example batch.
resnet11 = ResNet(num_classes=10, act_fn=nn.relu, block_class=ResNetBlock)
print(nn.tabulate(resnet11, jax.random.PRNGKey(42))(x))

In [None]:
from pprint import pprint
# Initialise the model variables from the example batch and show every leaf's shape.
variables = resnet11.init(jax.random.PRNGKey(42), x)
# jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
pprint(jax.tree_util.tree_map(jnp.shape, variables), width=100)

---

In [None]:
from functools import partial
resolution = 256

# Stack `resolution` identical copies of every leaf along a new leading axis.
# jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
tiled_variables = jax.tree_util.tree_map(
    lambda leaf: jnp.tile(leaf, (resolution,) + (1,) * len(leaf.shape)), variables)
jax.tree_util.tree_map(jnp.shape, tiled_variables)

In [None]:
@jax.jit
@partial(jax.vmap, in_axes=0, out_axes=0)
def f_dict(x):
    """Double every leaf of the pytree `x`; vmapped over each leaf's leading axis and jitted."""
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    return jax.tree_util.tree_map(lambda leaf: leaf * 2, x)

# Leading-axis size shared by every leaf, so vmap can map all of them over axis 0.
res = 10

# Toy pytree whose leaves have different ranks but the same leading dimension.
xdict = {
    'a': jnp.ones((res, 3, 3)),
    'b': jnp.ones((res, 3, 5, 6)),
    'c': jnp.ones((res,))
}

f_dict(xdict)

In [None]:
@jax.jit
def f_dict2(x):
    """Rescale dict entries by key: 'a' -> *0.01, 'b' -> *2, any key containing 'c' -> /2.

    Entries matching none of the rules are returned unchanged.
    """
    def _rescale(key, value):
        # The three rules can never match the same key, so first-match-wins is
        # equivalent to applying them one after another.
        if key == 'a':
            return value * 0.01
        if key == 'b':
            return value * 2
        if 'c' in key:
            return value / 2
        return value

    return {key: _rescale(key, value) for key, value in x.items()}

# Lower and compile f_dict2 for `xdict`, then print the compiled module as text.
# The work is wrapped in jax.checking_leaks() so leaked tracers are reported.
# print(f_dict2.lower(xdict).as_text('stablehlo'))
with jax.checking_leaks():
    txt = f_dict2.lower(xdict).compile().as_text()
print(txt)

In [None]:
# Show the shape of every parameter leaf.
# jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
pprint(jax.tree_util.tree_map(jnp.shape, variables)['params'])
# jax.tree_map(jnp.shape, variables)['batch_stats']


---

* v1. resnet

In [None]:
import optax
from tqdm import tqdm

# @jax.jit
def net_test(variables, x: jnp.array):
    """Hand-written forward pass of the ResNet built on raw `jax.lax.conv` calls.

    Every convolution uses stride 1, so the spatial size of the input is kept
    all the way to the dense layer.

    Args:
        variables: dict with 'params' and 'batch_stats' in the layout produced by
            `initialize` (conv kernels transposed for `jax.lax.conv`, batch
            statistics reshaped to broadcast against NCHW activations).
        x: input batch, assumed to be NHWC; it is transposed to NCHW internally.

    Returns:
        (probs, variables): softmax class probabilities and a new variables dict
        holding the updated batch statistics. The input dict is not modified.
    """
    params = variables['params']
    # Shallow copy so that the caller's batch_stats dict is never mutated.
    batch_stats = dict(variables['batch_stats'])

    def _block_order(name):
        # Sort by the numeric suffix: jit passes dicts in lexicographic key order,
        # which would place 'ResNetBlock_10' before 'ResNetBlock_2'.
        suffix = name.rsplit('_', 1)[-1]
        return (int(suffix) if suffix.isdigit() else -1, name)

    # NHWC -> NCHW, the layout jax.lax.conv works in.
    x = jnp.transpose(x, [0, 3, 1, 2])
    # 1st conv
    x = jax.lax.conv(x, params['Conv_0']['kernel'], window_strides=(1, 1), padding='SAME')
    # flax pooling treats the last axis as features, so pool in NHWC and switch
    # back; pooling the NCHW tensor directly would window over channels and height.
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = nn.max_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    x = jnp.transpose(x, [0, 3, 1, 2])

    # ResNetBlocks; conv0-conv1-skip
    for k in sorted((name for name in params if 'ResNetBlock' in name), key=_block_order):
        v = params[k]
        block_stats = dict(batch_stats[k])
        residual = x

        x = jax.lax.conv(x, v['Conv_0']['kernel'], window_strides=(1, 1), padding='SAME')
        x, block_stats['BatchNorm_0'] = batchnorm(
            x, v['BatchNorm_0'], dict(block_stats['BatchNorm_0']), on_train=True)
        x = nn.relu(x)

        x = jax.lax.conv(x, v['Conv_1']['kernel'], window_strides=(1, 1), padding='SAME')
        x, block_stats['BatchNorm_1'] = batchnorm(
            x, v['BatchNorm_1'], dict(block_stats['BatchNorm_1']), on_train=True)

        # Blocks that change the channel count carry a 1x1 projection for the skip path.
        if 'Conv_2' in v:
            residual = jax.lax.conv(residual, v['Conv_2']['kernel'], window_strides=(1, 1), padding='SAME')
        x = nn.relu(x + residual)
        batch_stats[k] = block_stats

    # FC head: go back to NHWC first so avg_pool windows over height and width.
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = nn.avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    x = x.reshape((x.shape[0], -1))
    x = jnp.dot(x, params['Dense_0']['kernel'])

    return nn.softmax(x), {'params': params, 'batch_stats': batch_stats}

def batchnorm(x, params_bn, batch_stats_bn, momentum=0.9, eps=1e-6, on_train=True):
    """Batch-normalise an activation tensor over axes (0, 2, 3).

    Args:
        x: 4-D activations with channels on axis 1.
        params_bn: dict with 1-D 'scale' and 'bias'
            (variables['params'][...]['BatchNorm_X']).
        batch_stats_bn: dict with running 'mean' and 'var' that broadcast against `x`
            (variables['batch_stats'][...]['BatchNorm_X']).
        momentum: weight given to the previous running statistics.
        eps: constant added to the variance before the square root.
        on_train: if truthy, normalise with the batch statistics and update the
            running averages; otherwise normalise with the running statistics.

    Returns:
        (y, new_stats): the normalised, scaled and shifted activations, and a new
        stats dict. `batch_stats_bn` itself is left untouched.
    """
    gamma = params_bn['scale']
    beta = params_bn['bias']
    gamma = gamma.reshape((1, gamma.shape[0], 1, 1))
    beta = beta.reshape((1, beta.shape[0], 1, 1))

    running_mu = batch_stats_bn['mean']
    running_var = batch_stats_bn['var']

    if on_train:
        mu = jnp.mean(x, axis=(0, 2, 3), keepdims=True)
        var = jnp.var(x, axis=(0, 2, 3), keepdims=True)
        running_mu = momentum * running_mu + (1 - momentum) * mu
        running_var = momentum * running_var + (1 - momentum) * var
        x = (x - mu) / jnp.sqrt(var + eps)
    else:
        x = (x - running_mu) / jnp.sqrt(running_var + eps)
    x = gamma * x + beta

    # Return a fresh dict instead of writing into the caller's one.
    new_stats = {**batch_stats_bn, 'mean': running_mu, 'var': running_var}
    return x, new_stats
    
# def loss_fn(variables, x, y):
#     logits, variables = net_test(variables, x)
#     return optax.softmax_cross_entropy_with_integer_labels(jnp.clip(logits, 1e-6, 1.), y).mean(), (logits, variables)

# @jax.jit
# def train_oneEpoch(variables, x, y, lr):
#     (loss, (logits, variables)), grads = jax.value_and_grad(loss_fn, has_aux=True)(variables, x, y)  
#     return loss, variables, grads

@jax.jit
def update_fn(variables, x, y, lr):
    """Run one jitted SGD step on a single batch.

    Args:
        variables: dict with 'params' and 'batch_stats' as consumed by `net_test`.
        x: input batch passed straight to `net_test`.
        y: integer class labels for `x`.
        lr: pytree with the same structure as variables['params'] holding one
            learning rate per leaf.

    Returns:
        (variables, (loss, probs)): the variables with updated params and batch
        statistics, the mean cross-entropy loss, and the model's softmax output.
    """
    def loss_fn(variables, x, y):
        probs, variables = net_test(variables, x)
        # net_test already returns softmax probabilities, while the optax loss
        # expects logits. Feed it log-probabilities so softmax is not applied twice.
        log_probs = jnp.log(jnp.clip(probs, 1e-6, 1.))
        loss = optax.softmax_cross_entropy_with_integer_labels(log_probs, y).mean()
        return loss, (probs, variables)

    (loss, (logits, variables)), grads = jax.value_and_grad(loss_fn, has_aux=True)(variables, x, y)
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    new_params = jax.tree_util.tree_map(
        lambda param, lr, g: param - lr * g, variables['params'], lr, grads['params'])
    return {**variables, 'params': new_params}, (loss, logits)
    
def train(variables, batches, lr, epochs):
    """Run `epochs` passes of SGD over `batches`.

    Args:
        variables: model variables accepted by `update_fn`.
        batches: dataset exposing `as_numpy_iterator()` whose items are dicts with
            'image' and 'label'.
        lr: scalar learning rate; expanded to one float32 scalar per parameter leaf.
        epochs: number of passes over `batches`.

    Returns:
        (loss_archive, logits_archive): the loss and the model output of the last
        batch of every epoch.
        NOTE(review): the trained `variables` are not returned -- confirm intended.
    """
    loss_archive, logits_archive = [], []
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    lr = jax.tree_util.tree_map(lambda _: jnp.array((lr), dtype=jnp.float32), variables['params'])

    for epoch in tqdm(range(epochs), total=epochs):
        for batch in batches.as_numpy_iterator():
            x = batch['image']
            y = batch['label']
            variables, (loss, logits) = update_fn(variables, x, y, lr)
        # Only the final batch of the epoch is recorded.
        loss_archive.append(loss)
        logits_archive.append(logits)
    return loss_archive, logits_archive

def initialize(module, rng, x):
    """Initialise `module` and convert its variables to the layout `net_test` expects.

    Args:
        module: model exposing `init`, `c_hidden` and `num_classes` (the ResNet above).
        rng: integer seed for the parameter PRNG key.
        x: example input batch, assumed NHWC.

    Returns:
        The variables dict with a re-drawn dense kernel, conv kernels transposed
        with (3, 2, 0, 1), and batch statistics reshaped to (1, C, 1, 1).
    """
    variables = module.init(jax.random.PRNGKey(rng), x)
    # net_test keeps stride 1 everywhere, so the flattened feature size is
    # last-stage channels * H * W of the input (64 * 28**2 = 50176 for MNIST).
    flat_dim = module.c_hidden[-1] * x.shape[1] * x.shape[2]
    variables['params']['Dense_0']['kernel'] = jax.nn.initializers.xavier_normal()(
        jax.random.PRNGKey(1), (flat_dim, module.num_classes))

    def conv_dog(kp, x):
        # Reorder conv kernels for jax.lax.conv; leave any non-4D leaf untouched.
        if 'Conv' in jax.tree_util.keystr(kp) and x.ndim == 4:
            x = jnp.transpose(x, (3, 2, 0, 1))
        return x
    variables['params'] = jax.tree_util.tree_map_with_path(conv_dog, variables['params'])
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    variables['batch_stats'] = jax.tree_util.tree_map(
        lambda stats: stats.reshape((1, stats.shape[0], 1, 1)), variables['batch_stats'])
    return variables
        

# variables = resnet11.init(jax.random.PRNGKey(42), x)

# custom_variables = variables.copy()
# custom_variables['params']['Dense_0']['kernel'] = jax.nn.initializers.xavier_normal()(jax.random.PRNGKey(1), (50176, 10))
# custom_variables['batch_stats']['ResNetBlock_0']['BatchNorm_0']['mean'] = custom_variables['batch_stats']['ResNetBlock_0']['BatchNorm_0']['mean'].reshape((1, custom_variables['batch_stats']['ResNetBlock_0']['BatchNorm_0']['mean'].shape[0], 1, 1))
# custom_variables['batch_stats']['ResNetBlock_0']['BatchNorm_0']['var'] = custom_variables['batch_stats']['ResNetBlock_0']['BatchNorm_0']['var'].reshape((1, custom_variables['batch_stats']['ResNetBlock_0']['BatchNorm_0']['var'].shape[0], 1, 1))
# v1 run: single-device training for 10 epochs with a learning rate of 0.01.
# NOTE(review): `train` returns only the archives, so `variables` still holds the
# initial values after this cell.
variables = initialize(resnet11, 42, x)
loss_archive, logits_archive = train(variables, train_ds, lr=0.01, epochs=10)

* v2. resnet+parallelism

In [None]:
def shard_data(data, n_devices):
    """Split the leading axis of `data` into `n_devices` equally sized shards.

    Returns an array of shape (n_devices, data.shape[0] // n_devices, *data.shape[1:]).
    """
    per_device = data.shape[0] // n_devices
    sharded_shape = (n_devices, per_device) + tuple(data.shape[1:])
    return data.reshape(sharded_shape)

# Sanity check: shape of the example batch after splitting it into 4 shards.
shard_data(x, 4).shape

In [None]:
import optax
from tqdm import tqdm
from functools import partial

def net_test(variables, x: jnp.array):
    """Hand-written forward pass of the ResNet built on raw `jax.lax.conv` calls.

    Every convolution uses stride 1, so the spatial size of the input is kept
    all the way to the dense layer.

    Args:
        variables: dict with 'params' and 'batch_stats' in the layout produced by
            `initialize` (conv kernels transposed for `jax.lax.conv`, batch
            statistics reshaped to broadcast against NCHW activations).
        x: input batch, assumed to be NHWC; it is transposed to NCHW internally.

    Returns:
        (probs, variables): softmax class probabilities and a new variables dict
        holding the updated batch statistics. The input dict is not modified.
    """
    params = variables['params']
    # Shallow copy so that the caller's batch_stats dict is never mutated.
    batch_stats = dict(variables['batch_stats'])

    def _block_order(name):
        # Sort by the numeric suffix: jit/pmap pass dicts in lexicographic key
        # order, which would place 'ResNetBlock_10' before 'ResNetBlock_2'.
        suffix = name.rsplit('_', 1)[-1]
        return (int(suffix) if suffix.isdigit() else -1, name)

    # NHWC -> NCHW, the layout jax.lax.conv works in.
    x = jnp.transpose(x, [0, 3, 1, 2])

    # 1st conv
    x = jax.lax.conv(x, params['Conv_0']['kernel'], window_strides=(1, 1), padding='SAME')
    # flax pooling treats the last axis as features, so pool in NHWC and switch
    # back; pooling the NCHW tensor directly would window over channels and height.
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = nn.max_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    x = jnp.transpose(x, [0, 3, 1, 2])

    # ResNetBlocks; conv0-conv1-skip
    for k in sorted((name for name in params if 'ResNetBlock' in name), key=_block_order):
        v = params[k]
        block_stats = dict(batch_stats[k])
        residual = x

        x = jax.lax.conv(x, v['Conv_0']['kernel'], window_strides=(1, 1), padding='SAME')
        x, block_stats['BatchNorm_0'] = batchnorm(
            x, v['BatchNorm_0'], dict(block_stats['BatchNorm_0']), on_train=True)
        x = nn.relu(x)

        x = jax.lax.conv(x, v['Conv_1']['kernel'], window_strides=(1, 1), padding='SAME')
        x, block_stats['BatchNorm_1'] = batchnorm(
            x, v['BatchNorm_1'], dict(block_stats['BatchNorm_1']), on_train=True)

        # Blocks that change the channel count carry a 1x1 projection for the skip path.
        if 'Conv_2' in v:
            residual = jax.lax.conv(residual, v['Conv_2']['kernel'], window_strides=(1, 1), padding='SAME')
        x = nn.relu(x + residual)
        batch_stats[k] = block_stats

    # FC head: go back to NHWC first so avg_pool windows over height and width.
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = nn.avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    x = x.reshape((x.shape[0], -1))
    x = jnp.dot(x, params['Dense_0']['kernel'])

    return nn.softmax(x), {'params': params, 'batch_stats': batch_stats}

def batchnorm(x, params_bn, batch_stats_bn, momentum=0.9, eps=1e-6, on_train=True):
    """Batch-normalise an activation tensor over axes (0, 2, 3).

    Args:
        x: 4-D activations with channels on axis 1.
        params_bn: dict with 1-D 'scale' and 'bias'
            (variables['params'][...]['BatchNorm_X']).
        batch_stats_bn: dict with running 'mean' and 'var' that broadcast against `x`
            (variables['batch_stats'][...]['BatchNorm_X']).
        momentum: weight given to the previous running statistics.
        eps: constant added to the variance before the square root.
        on_train: if truthy, normalise with the batch statistics and update the
            running averages; otherwise normalise with the running statistics.

    Returns:
        (y, new_stats): the normalised, scaled and shifted activations, and a new
        stats dict. `batch_stats_bn` itself is left untouched.
    """
    gamma = params_bn['scale']
    beta = params_bn['bias']
    gamma = gamma.reshape((1, gamma.shape[0], 1, 1))
    beta = beta.reshape((1, beta.shape[0], 1, 1))

    running_mu = batch_stats_bn['mean']
    running_var = batch_stats_bn['var']

    if on_train:
        mu = jnp.mean(x, axis=(0, 2, 3), keepdims=True)
        var = jnp.var(x, axis=(0, 2, 3), keepdims=True)
        running_mu = momentum * running_mu + (1 - momentum) * mu
        running_var = momentum * running_var + (1 - momentum) * var
        x = (x - mu) / jnp.sqrt(var + eps)
    else:
        x = (x - running_mu) / jnp.sqrt(running_var + eps)
    x = gamma * x + beta

    # Return a fresh dict instead of writing into the caller's one.
    new_stats = {**batch_stats_bn, 'mean': running_mu, 'var': running_var}
    return x, new_stats
    
# @jax.jit
@partial(jax.pmap, axis_name='batch', in_axes=(None, 0, 0, None), out_axes=(None, 0))
def update_fn(variables, x, y, lr):
    """Run one data-parallel SGD step; each device processes one shard of the batch.

    Args:
        variables: dict with 'params' and 'batch_stats', replicated to every device.
        x: sharded inputs with the device axis first (see `shard_data`).
        y: sharded integer class labels with the device axis first.
        lr: pytree with the same structure as variables['params'] holding one
            learning rate per leaf; replicated to every device.

    Returns:
        (variables, (loss, probs)): the updated variables (identical on every
        device), the loss averaged over devices, and each device's softmax output.
    """
    def loss_fn(variables, x, y):
        probs, variables = net_test(variables, x)
        # net_test already returns softmax probabilities, while the optax loss
        # expects logits. Feed it log-probabilities so softmax is not applied twice.
        log_probs = jnp.log(jnp.clip(probs, 1e-6, 1.))
        loss = optax.softmax_cross_entropy_with_integer_labels(log_probs, y).mean()
        return loss, (probs, variables)

    (loss, (logits, variables)), grads = jax.value_and_grad(loss_fn, has_aux=True)(variables, x, y)
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')
    # The per-example outputs are deliberately not averaged: every device holds
    # different samples, so a mean across devices would mix unrelated predictions.
    # Batch statistics are computed per shard; average them so the variables that
    # leave through out_axes=None really are the same on every device.
    batch_stats = jax.lax.pmean(variables['batch_stats'], axis_name='batch')
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    new_params = jax.tree_util.tree_map(
        lambda param, lr, g: param - lr * g, variables['params'], lr, grads['params'])
    return {**variables, 'params': new_params, 'batch_stats': batch_stats}, (loss, logits)

def train(variables, batches, lr, epochs):
    """Run `epochs` passes of data-parallel SGD over `batches`.

    Every batch is split into 4 shards with `shard_data` before it is handed to
    the pmapped `update_fn`.

    Args:
        variables: model variables accepted by `update_fn`.
        batches: dataset exposing `as_numpy_iterator()` whose items are dicts with
            'image' and 'label'.
        lr: scalar learning rate; expanded to one float32 scalar per parameter leaf.
        epochs: number of passes over `batches`.

    Returns:
        (loss_archive, logits_archive): the loss and the model output of the last
        batch of every epoch.
        NOTE(review): the trained `variables` are not returned -- confirm intended.
    """
    loss_archive, logits_archive = [], []
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    lr = jax.tree_util.tree_map(lambda _: jnp.array((lr), dtype=jnp.float32), variables['params'])

    for epoch in tqdm(range(epochs), total=epochs):
        for batch in batches.as_numpy_iterator():
            x = shard_data(batch['image'], 4)
            y = shard_data(batch['label'], 4)

            variables, (loss, logits) = update_fn(variables, x, y, lr)
        # Only the final batch of the epoch is recorded.
        loss_archive.append(loss)
        logits_archive.append(logits)
    return loss_archive, logits_archive

def initialize(module, rng, x):
    """Initialise `module` and convert its variables to the layout `net_test` expects.

    Args:
        module: model exposing `init`, `c_hidden` and `num_classes` (the ResNet above).
        rng: integer seed for the parameter PRNG key.
        x: example input batch, assumed NHWC.

    Returns:
        The variables dict with a re-drawn dense kernel, conv kernels transposed
        with (3, 2, 0, 1), and batch statistics reshaped to (1, C, 1, 1).
    """
    variables = module.init(jax.random.PRNGKey(rng), x)
    # net_test keeps stride 1 everywhere, so the flattened feature size is
    # last-stage channels * H * W of the input (64 * 28**2 = 50176 for MNIST).
    flat_dim = module.c_hidden[-1] * x.shape[1] * x.shape[2]
    variables['params']['Dense_0']['kernel'] = jax.nn.initializers.xavier_normal()(
        jax.random.PRNGKey(1), (flat_dim, module.num_classes))

    def conv_dog(kp, x):
        # Reorder conv kernels for jax.lax.conv; leave any non-4D leaf untouched.
        if 'Conv' in jax.tree_util.keystr(kp) and x.ndim == 4:
            x = jnp.transpose(x, (3, 2, 0, 1))
        return x
    variables['params'] = jax.tree_util.tree_map_with_path(conv_dog, variables['params'])
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    variables['batch_stats'] = jax.tree_util.tree_map(
        lambda stats: stats.reshape((1, stats.shape[0], 1, 1)), variables['batch_stats'])
    return variables
        

# v2 smoke run: data-parallel training for 2 epochs with a learning rate of 0.01.
# NOTE(review): `train` returns only the archives, so `variables` still holds the
# initial values after this cell.
variables = initialize(resnet11, 42, x)
loss_archive, logits_archive = train(variables, train_ds, lr=0.01, epochs=2)

In [None]:
import optax
from tqdm import tqdm
from functools import partial

def net_test(variables, x: jnp.array):
    """Hand-written forward pass of the ResNet built on raw `jax.lax.conv` calls.

    Every convolution uses stride 1, so the spatial size of the input is kept
    all the way to the dense layer.

    Args:
        variables: dict with 'params' and 'batch_stats' in the layout produced by
            `initialize` (conv kernels transposed for `jax.lax.conv`, batch
            statistics reshaped to broadcast against NCHW activations).
        x: input batch, assumed to be NHWC; it is transposed to NCHW internally.

    Returns:
        (probs, variables): softmax class probabilities and a new variables dict
        holding the updated batch statistics. The input dict is not modified.
    """
    params = variables['params']
    # Shallow copy so that the caller's batch_stats dict is never mutated.
    batch_stats = dict(variables['batch_stats'])

    def _block_order(name):
        # Sort by the numeric suffix: jit/pmap pass dicts in lexicographic key
        # order, which would place 'ResNetBlock_10' before 'ResNetBlock_2'.
        suffix = name.rsplit('_', 1)[-1]
        return (int(suffix) if suffix.isdigit() else -1, name)

    # NHWC -> NCHW, the layout jax.lax.conv works in.
    x = jnp.transpose(x, [0, 3, 1, 2])

    # 1st conv
    x = jax.lax.conv(x, params['Conv_0']['kernel'], window_strides=(1, 1), padding='SAME')
    # flax pooling treats the last axis as features, so pool in NHWC and switch
    # back; pooling the NCHW tensor directly would window over channels and height.
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = nn.max_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    x = jnp.transpose(x, [0, 3, 1, 2])

    # ResNetBlocks; conv0-conv1-skip
    for k in sorted((name for name in params if 'ResNetBlock' in name), key=_block_order):
        v = params[k]
        block_stats = dict(batch_stats[k])
        residual = x

        x = jax.lax.conv(x, v['Conv_0']['kernel'], window_strides=(1, 1), padding='SAME')
        x, block_stats['BatchNorm_0'] = batchnorm(
            x, v['BatchNorm_0'], dict(block_stats['BatchNorm_0']), on_train=True)
        x = nn.relu(x)

        x = jax.lax.conv(x, v['Conv_1']['kernel'], window_strides=(1, 1), padding='SAME')
        x, block_stats['BatchNorm_1'] = batchnorm(
            x, v['BatchNorm_1'], dict(block_stats['BatchNorm_1']), on_train=True)

        # Blocks that change the channel count carry a 1x1 projection for the skip path.
        if 'Conv_2' in v:
            residual = jax.lax.conv(residual, v['Conv_2']['kernel'], window_strides=(1, 1), padding='SAME')
        x = nn.relu(x + residual)
        batch_stats[k] = block_stats

    # FC head: go back to NHWC first so avg_pool windows over height and width.
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = nn.avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    x = x.reshape((x.shape[0], -1))
    x = jnp.dot(x, params['Dense_0']['kernel'])

    return nn.softmax(x), {'params': params, 'batch_stats': batch_stats}

def batchnorm(x, params_bn, batch_stats_bn, momentum=0.9, eps=1e-6, on_train=True):
    """Batch-normalise an activation tensor over axes (0, 2, 3).

    Args:
        x: 4-D activations with channels on axis 1.
        params_bn: dict with 1-D 'scale' and 'bias'
            (variables['params'][...]['BatchNorm_X']).
        batch_stats_bn: dict with running 'mean' and 'var' that broadcast against `x`
            (variables['batch_stats'][...]['BatchNorm_X']).
        momentum: weight given to the previous running statistics.
        eps: constant added to the variance before the square root.
        on_train: if truthy, normalise with the batch statistics and update the
            running averages; otherwise normalise with the running statistics.

    Returns:
        (y, new_stats): the normalised, scaled and shifted activations, and a new
        stats dict. `batch_stats_bn` itself is left untouched.
    """
    gamma = params_bn['scale']
    beta = params_bn['bias']
    gamma = gamma.reshape((1, gamma.shape[0], 1, 1))
    beta = beta.reshape((1, beta.shape[0], 1, 1))

    running_mu = batch_stats_bn['mean']
    running_var = batch_stats_bn['var']

    if on_train:
        mu = jnp.mean(x, axis=(0, 2, 3), keepdims=True)
        var = jnp.var(x, axis=(0, 2, 3), keepdims=True)
        running_mu = momentum * running_mu + (1 - momentum) * mu
        running_var = momentum * running_var + (1 - momentum) * var
        x = (x - mu) / jnp.sqrt(var + eps)
    else:
        x = (x - running_mu) / jnp.sqrt(running_var + eps)
    x = gamma * x + beta

    # Return a fresh dict instead of writing into the caller's one.
    new_stats = {**batch_stats_bn, 'mean': running_mu, 'var': running_var}
    return x, new_stats
    
# @jax.jit
@partial(jax.pmap, axis_name='batch', in_axes=(None, 0, 0, None), out_axes=(None, 0))
def update_fn(variables, x, y, lr):
    """Run one data-parallel SGD step; each device processes one shard of the batch.

    Args:
        variables: dict with 'params' and 'batch_stats', replicated to every device.
        x: sharded inputs with the device axis first (see `shard_data`).
        y: sharded integer class labels with the device axis first.
        lr: pytree with the same structure as variables['params'] holding one
            learning rate per leaf; replicated to every device.

    Returns:
        (variables, (loss, probs)): the updated variables (identical on every
        device), plus each device's own loss and softmax output.
    """
    def loss_fn(variables, x, y):
        probs, variables = net_test(variables, x)
        # net_test already returns softmax probabilities, while the optax loss
        # expects logits. Feed it log-probabilities so softmax is not applied twice.
        log_probs = jnp.log(jnp.clip(probs, 1e-6, 1.))
        loss = optax.softmax_cross_entropy_with_integer_labels(log_probs, y).mean()
        return loss, (probs, variables)

    (loss, (logits, variables)), grads = jax.value_and_grad(loss_fn, has_aux=True)(variables, x, y)
    grads = jax.lax.pmean(grads, axis_name='batch')
    # Batch statistics are computed per shard; average them so the variables that
    # leave through out_axes=None really are the same on every device.
    batch_stats = jax.lax.pmean(variables['batch_stats'], axis_name='batch')
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    new_params = jax.tree_util.tree_map(
        lambda param, lr, g: param - lr * g, variables['params'], lr, grads['params'])
    return {**variables, 'params': new_params, 'batch_stats': batch_stats}, (loss, logits)

def train(variables, batches, lr, epochs):
    """Run `epochs` passes of data-parallel SGD over `batches`.

    Every batch is split into 4 shards with `shard_data` before it is handed to
    the pmapped `update_fn`.

    Args:
        variables: model variables accepted by `update_fn`.
        batches: dataset exposing `as_numpy_iterator()` whose items are dicts with
            'image' and 'label'.
        lr: scalar learning rate; expanded to one float32 scalar per parameter leaf.
        epochs: number of passes over `batches`.

    Returns:
        (loss_archive, logits_archive): the loss and the model output of the last
        batch of every epoch.
        NOTE(review): the trained `variables` are not returned -- confirm intended.
    """
    loss_archive, logits_archive = [], []
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    lr = jax.tree_util.tree_map(lambda _: jnp.array((lr), dtype=jnp.float32), variables['params'])

    for epoch in tqdm(range(epochs), total=epochs):
        for batch in batches.as_numpy_iterator():
            x = shard_data(batch['image'], 4)
            y = shard_data(batch['label'], 4)

            variables, (loss, logits) = update_fn(variables, x, y, lr)
        # Only the final batch of the epoch is recorded.
        loss_archive.append(loss)
        logits_archive.append(logits)
    return loss_archive, logits_archive

def initialize(module, rng, x):
    """Initialise `module` and convert its variables to the layout `net_test` expects.

    Args:
        module: model exposing `init`, `c_hidden` and `num_classes` (the ResNet above).
        rng: integer seed for the parameter PRNG key.
        x: example input batch, assumed NHWC.

    Returns:
        The variables dict with a re-drawn dense kernel, conv kernels transposed
        with (3, 2, 0, 1), and batch statistics reshaped to (1, C, 1, 1).
    """
    variables = module.init(jax.random.PRNGKey(rng), x)
    # net_test keeps stride 1 everywhere, so the flattened feature size is
    # last-stage channels * H * W of the input (64 * 28**2 = 50176 for MNIST).
    flat_dim = module.c_hidden[-1] * x.shape[1] * x.shape[2]
    variables['params']['Dense_0']['kernel'] = jax.nn.initializers.xavier_normal()(
        jax.random.PRNGKey(1), (flat_dim, module.num_classes))

    def conv_dog(kp, x):
        # Reorder conv kernels for jax.lax.conv; leave any non-4D leaf untouched.
        if 'Conv' in jax.tree_util.keystr(kp) and x.ndim == 4:
            x = jnp.transpose(x, (3, 2, 0, 1))
        return x
    variables['params'] = jax.tree_util.tree_map_with_path(conv_dog, variables['params'])
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    variables['batch_stats'] = jax.tree_util.tree_map(
        lambda stats: stats.reshape((1, stats.shape[0], 1, 1)), variables['batch_stats'])
    return variables
        

# v2 full run: data-parallel training for 100 epochs with a learning rate of 0.01.
# NOTE(review): `train` returns only the archives, so `variables` still holds the
# initial values after this cell.
variables = initialize(resnet11, 42, x)
loss_archive, logits_archive = train(variables, train_ds, lr=0.01, epochs=100)

* v3. parallelism + vmapped theta

In [None]:
import optax
from tqdm import tqdm
from functools import partial

def net_test(variables, x: jnp.array):
    """Hand-written forward pass of the ResNet built on raw `jax.lax.conv` calls.

    Every convolution uses stride 1, so the spatial size of the input is kept
    all the way to the dense layer.

    Args:
        variables: dict with 'params' and 'batch_stats' in the layout produced by
            `initialize` (conv kernels transposed for `jax.lax.conv`, batch
            statistics reshaped to broadcast against NCHW activations).
        x: input batch, assumed to be NHWC; it is transposed to NCHW internally.

    Returns:
        (probs, variables): softmax class probabilities and a new variables dict
        holding the updated batch statistics. The input dict is not modified.
    """
    params = variables['params']
    # Shallow copy so that the caller's batch_stats dict is never mutated.
    batch_stats = dict(variables['batch_stats'])

    def _block_order(name):
        # Sort by the numeric suffix: jit/pmap/vmap pass dicts in lexicographic key
        # order, which would place 'ResNetBlock_10' before 'ResNetBlock_2'.
        suffix = name.rsplit('_', 1)[-1]
        return (int(suffix) if suffix.isdigit() else -1, name)

    # NHWC -> NCHW, the layout jax.lax.conv works in.
    x = jnp.transpose(x, [0, 3, 1, 2])

    # 1st conv
    x = jax.lax.conv(x, params['Conv_0']['kernel'], window_strides=(1, 1), padding='SAME')
    # flax pooling treats the last axis as features, so pool in NHWC and switch
    # back; pooling the NCHW tensor directly would window over channels and height.
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = nn.max_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    x = jnp.transpose(x, [0, 3, 1, 2])

    # ResNetBlocks; conv0-conv1-skip
    for k in sorted((name for name in params if 'ResNetBlock' in name), key=_block_order):
        v = params[k]
        block_stats = dict(batch_stats[k])
        residual = x

        x = jax.lax.conv(x, v['Conv_0']['kernel'], window_strides=(1, 1), padding='SAME')
        x, block_stats['BatchNorm_0'] = batchnorm(
            x, v['BatchNorm_0'], dict(block_stats['BatchNorm_0']), on_train=True)
        x = nn.relu(x)

        x = jax.lax.conv(x, v['Conv_1']['kernel'], window_strides=(1, 1), padding='SAME')
        x, block_stats['BatchNorm_1'] = batchnorm(
            x, v['BatchNorm_1'], dict(block_stats['BatchNorm_1']), on_train=True)

        # Blocks that change the channel count carry a 1x1 projection for the skip path.
        if 'Conv_2' in v:
            residual = jax.lax.conv(residual, v['Conv_2']['kernel'], window_strides=(1, 1), padding='SAME')
        x = nn.relu(x + residual)
        batch_stats[k] = block_stats

    # FC head: go back to NHWC first so avg_pool windows over height and width.
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = nn.avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    x = x.reshape((x.shape[0], -1))
    x = jnp.dot(x, params['Dense_0']['kernel'])

    return nn.softmax(x), {'params': params, 'batch_stats': batch_stats}

def batchnorm(x, params_bn, batch_stats_bn, momentum=0.9, eps=1e-6, on_train=True):
    """Batch-normalise an activation tensor over axes (0, 2, 3).

    Args:
        x: 4-D activations with channels on axis 1.
        params_bn: dict with 1-D 'scale' and 'bias'
            (variables['params'][...]['BatchNorm_X']).
        batch_stats_bn: dict with running 'mean' and 'var' that broadcast against `x`
            (variables['batch_stats'][...]['BatchNorm_X']).
        momentum: weight given to the previous running statistics.
        eps: constant added to the variance before the square root.
        on_train: if truthy, normalise with the batch statistics and update the
            running averages; otherwise normalise with the running statistics.

    Returns:
        (y, new_stats): the normalised, scaled and shifted activations, and a new
        stats dict. `batch_stats_bn` itself is left untouched.
    """
    gamma = params_bn['scale']
    beta = params_bn['bias']
    gamma = gamma.reshape((1, gamma.shape[0], 1, 1))
    beta = beta.reshape((1, beta.shape[0], 1, 1))

    running_mu = batch_stats_bn['mean']
    running_var = batch_stats_bn['var']

    if on_train:
        mu = jnp.mean(x, axis=(0, 2, 3), keepdims=True)
        var = jnp.var(x, axis=(0, 2, 3), keepdims=True)
        running_mu = momentum * running_mu + (1 - momentum) * mu
        running_var = momentum * running_var + (1 - momentum) * var
        x = (x - mu) / jnp.sqrt(var + eps)
    else:
        x = (x - running_mu) / jnp.sqrt(running_var + eps)
    x = gamma * x + beta

    # Return a fresh dict instead of writing into the caller's one.
    new_stats = {**batch_stats_bn, 'mean': running_mu, 'var': running_var}
    return x, new_stats
    
# @jax.jit
@partial(jax.pmap, axis_name='batch', in_axes=(None, 0, 0, None), out_axes=(None, 0))
@partial(jax.vmap, in_axes=(0, None, None, None))
def update_fn(variables, x, y, lr):
    """Run one data-parallel SGD step for a stack of models at once.

    pmap splits the batch across devices; the inner vmap maps over the leading
    axis of every leaf in `variables`, so each stacked model is updated
    independently on the same shard.

    Args:
        variables: dict with 'params' and 'batch_stats' whose leaves carry an
            extra leading model axis; replicated to every device.
        x: sharded inputs with the device axis first (see `shard_data`).
        y: sharded integer class labels with the device axis first.
        lr: pytree with the same structure as variables['params'] holding one
            learning rate per leaf; shared by all stacked models.

    Returns:
        (variables, (loss, probs)): the updated stacked variables (identical on
        every device), plus each device's per-model loss and softmax output.
    """
    @jax.jit
    def loss_fn(variables, x, y):
        probs, variables = net_test(variables, x)
        # net_test already returns softmax probabilities, while the optax loss
        # expects logits. Feed it log-probabilities so softmax is not applied twice.
        log_probs = jnp.log(jnp.clip(probs, 1e-6, 1.))
        loss = optax.softmax_cross_entropy_with_integer_labels(log_probs, y).mean()
        return loss, (probs, variables)

    (loss, (logits, variables)), grads = jax.value_and_grad(loss_fn, has_aux=True)(variables, x, y)
    grads = jax.lax.pmean(grads, axis_name='batch')
    # Batch statistics are computed per shard; average them so the variables that
    # leave through out_axes=None really are the same on every device.
    batch_stats = jax.lax.pmean(variables['batch_stats'], axis_name='batch')
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    new_params = jax.tree_util.tree_map(
        lambda param, lr, g: param - lr * g, variables['params'], lr, grads['params'])
    return {**variables, 'params': new_params, 'batch_stats': batch_stats}, (loss, logits)

def train(variables, batches, lr, epochs):
    """Run `epochs` passes of data-parallel SGD over `batches`.

    Every batch is split into 4 shards with `shard_data` before it is handed to
    the pmapped (and vmapped) `update_fn`.

    Args:
        variables: model variables accepted by `update_fn`.
        batches: dataset exposing `as_numpy_iterator()` whose items are dicts with
            'image' and 'label'.
        lr: scalar learning rate; expanded to one float32 scalar per parameter leaf.
        epochs: number of passes over `batches`.

    Returns:
        (loss_archive, logits_archive): the loss and the model output of the last
        batch of every epoch.
        NOTE(review): the trained `variables` are not returned -- confirm intended.
    """
    loss_archive, logits_archive = [], []
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    lr = jax.tree_util.tree_map(lambda _: jnp.array((lr), dtype=jnp.float32), variables['params'])

    for epoch in tqdm(range(epochs), total=epochs):
        # vmapping needed
        for batch in batches.as_numpy_iterator():
            x = shard_data(batch['image'], 4)
            y = shard_data(batch['label'], 4)

            variables, (loss, logits) = update_fn(variables, x, y, lr)
        # Only the final batch of the epoch is recorded.
        loss_archive.append(loss)
        logits_archive.append(logits)
    return loss_archive, logits_archive

def initialize(module, rng, x):
    """Initialise `module` and convert its variables to the layout `net_test` expects.

    Args:
        module: model exposing `init`, `c_hidden` and `num_classes` (the ResNet above).
        rng: integer seed for the parameter PRNG key.
        x: example input batch, assumed NHWC.

    Returns:
        The variables dict with a re-drawn dense kernel, conv kernels transposed
        with (3, 2, 0, 1), and batch statistics reshaped to (1, C, 1, 1).
    """
    variables = module.init(jax.random.PRNGKey(rng), x)
    # net_test keeps stride 1 everywhere, so the flattened feature size is
    # last-stage channels * H * W of the input (64 * 28**2 = 50176 for MNIST).
    flat_dim = module.c_hidden[-1] * x.shape[1] * x.shape[2]
    variables['params']['Dense_0']['kernel'] = jax.nn.initializers.xavier_normal()(
        jax.random.PRNGKey(1), (flat_dim, module.num_classes))

    def conv_dog(kp, x):
        # Reorder conv kernels for jax.lax.conv; leave any non-4D leaf untouched.
        if 'Conv' in jax.tree_util.keystr(kp) and x.ndim == 4:
            x = jnp.transpose(x, (3, 2, 0, 1))
        return x
    variables['params'] = jax.tree_util.tree_map_with_path(conv_dog, variables['params'])
    # jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
    variables['batch_stats'] = jax.tree_util.tree_map(
        lambda stats: stats.reshape((1, stats.shape[0], 1, 1)), variables['batch_stats'])
    return variables
        

resolution = 256

variables = initialize(resnet11, 42, x)
# Stack `resolution` identical copies of every leaf along a new leading axis, the
# axis the vmapped update_fn maps over.
# jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
variables = jax.tree_util.tree_map(
    lambda leaf: jnp.tile(leaf, (resolution,) + (1,) * len(leaf.shape)), variables)
# loss_archive, logits_archive = train(variables, train_ds, lr=0.01, epochs=100)

In [22]:
# Rebuild the un-tiled variables for the scratch cells below.
variables = initialize(resnet11, 42, x)


In [None]:
# jax.tree_map(lambda x, y, z: x * y + z, 
#              {'a': jnp.array([1,2,3]), 'b': jnp.array([4,5,6])}, 
#              {'a': jnp.array([0.01]), 'b': jnp.array([0.01])}, 
#              {'a': jnp.array([1,2,3]), 'b': jnp.array([4,5,6])})

# Build a pytree shaped like the params with a 0.01 learning rate at every leaf.
# jax.tree_map was deprecated and later removed; jax.tree_util.tree_map is the stable spelling.
jax.tree_util.tree_map(lambda x: jnp.array((0.01)), variables['params'])

In [None]:
# flattened, _ = jax.tree_util.tree_flatten_with_path(variables['params'])
# for kp, v in flattened:
#     print(jax.tree_util.keystr(kp))
    # for k in kp:
    #     k = k.key
    #     if 'Conv' in k:
    #         # print(jax.tree_util.tree_map_with_path(lambda kp, x: x.shape, variables['params']))
    #         print(k)


def conv_dog(kp, x):
    """Transpose conv kernels with axes (3, 2, 0, 1); return every other leaf unchanged.

    Meant to be used with `jax.tree_util.tree_map_with_path`.

    Args:
        kp: key path of the leaf.
        x: the leaf array.

    Returns:
        The transposed array when the path contains 'Conv' and the leaf is 4-D,
        otherwise `x` itself.
    """
    path = jax.tree_util.keystr(kp)
    # Only 4-D leaves are kernels; a 1-D conv bias would make the transpose fail.
    if 'Conv' in path and x.ndim == 4:
        x = jnp.transpose(x, (3, 2, 0, 1))
    return x

# Apply the layout change to every conv kernel and inspect one resulting shape.
# NOTE(review): if `variables` came from `initialize`, its conv kernels were already
# transposed once -- confirm that transposing them again here is intended.
vv = jax.tree_util.tree_map_with_path(conv_dog, variables['params'])
vv['Conv_0']['kernel'].shape

In [None]:
import optax

def net(x: jnp.array, variables: dict, training=True) -> jnp.array:
    blocks = variables['params']
    for block in blocks:
        conv = jax.lax.conv()
        bn = batchnorm(..., training=True)    # memory batch_stats in this line
        dot = jnp.dot()
        relu = jax.nn.relu()
    softmax = jax.nn.softmax()
    return softmax

def batchnorm(...):
    bs = variables['batch_stats']
    ...

def loss_fn(variables: dict, x: jnp.array, y: jnp.array, training=True) -> jnp.array:
    x = net(..., training=training)
    return optax.softmax_cross_entropy_with_integer_labels(x, y)

@jax.jit
@jax.vmap(in_axes=(0, 0, None, None))
def train_oneEpoch(variables: dict, lr: jnp.array, x: jnp.array, y: jnp.array):
    loss, grads = jax.value_and_grad(loss_fn)(...)
    def update(variables, grads, lr):
        variables['params'] = variables['params'] - lr * grads  # jax.tree_map
        return variables
    trace_on_loss(loss)
    variables = update(variables, grads, lr)
    return variables

@jax.jit
@jax.vmap
def validate_oneEpoch(variables, x, y):
    loss = loss_fn(..., training=False)
    trace_on_loss(loss)
    
def train(num_epochs, variables, lr, x, y):
    for i in range(num_epochs):
        variables = train_oneEpoch(...)
        _ = validate_oneEpoch(...)

    deque_fn()  # ???: class?? or return something??
    return ...



def trace_on_loss(loss):
    # deques of loss and accuracy
    ...






