In [1]:
import jax
import flax
from flax import linen as nn
from model.resnet_v3 import *

In [19]:
from datasets.mnist import *

# Number of examples per batch for both splits.
batch_size = 32

# `train_ds` / `test_ds` are not defined in this notebook; presumably they come
# from `from datasets.mnist import *` -- TODO confirm.
# NOTE(review): both names are rebound to their own transformed version, so this
# cell is not idempotent -- re-running it without re-importing applies
# data_normalize/shuffle/batch/take a second time to the already-batched data.
# `.take(10)` keeps at most 10 batches of each split.
train_ds = data_normalize(train_ds).shuffle(buffer_size=10, seed=42).batch(batch_size).prefetch(1).take(10)
test_ds = data_normalize(test_ds).shuffle(buffer_size=10, seed=42).batch(batch_size).prefetch(1).take(10)

# Number of batches in each (truncated) split.
total_batch = train_ds.cardinality().numpy()
total_tbatch = test_ds.cardinality().numpy()

# Grab the first training batch as sample input (`x`) and labels (`y`);
# `x` is used by a later cell to initialise the model.
for batch in train_ds.as_numpy_iterator():
    x = batch['image']
    y = batch['label']
    break

In [6]:
# Model hyper-parameters: number of output classes, residual blocks per stage,
# and channel width per stage.
target_dim = 10
num_blocks = (3, 3, 3)
c_hidden = (16, 32, 64)

# NOTE(review): `resnet` is built from the hyper-parameters above but is not
# referenced again in the visible cells; everything below uses `resnet20`.
resnet = ResNet(num_classes=target_dim, act_fn=nn.relu, block_class=ResNetBlock, num_blocks=num_blocks, c_hidden=c_hidden)
# `resnet20` relies on ResNet's default num_blocks / c_hidden -- presumably the
# same values as above; verify against model.resnet_v3.
resnet20 = ResNet(10, nn.relu, ResNetBlock)
# Initialise the model variables from a fixed PRNG key, using the sample batch
# `x` loaded in the data cell as the example input.
variables = resnet20.init(jax.random.PRNGKey(1), x)

In [12]:
# Split the initialised variables into trainable parameters and the BatchNorm
# running statistics.
params = variables['params']
batch_stats = variables['batch_stats']

# One training-mode forward pass on the sample batch. `batch_stats` is declared
# mutable so the updated running statistics come back in `updates`.
# The output is bound to `logits` rather than `y`, so the label batch loaded in
# the data cell is not silently replaced by network outputs.
logits, updates = resnet20.apply(
    {'params': params, 'batch_stats': batch_stats},
    x,
    on_train=True,
    mutable=['batch_stats']
)
# Keep the refreshed running statistics for the TrainState created below.
batch_stats = updates['batch_stats']

In [14]:
from flax.training import train_state
from typing import Any

class TrainState(train_state.TrainState):
  """Flax ``TrainState`` extended with a ``batch_stats`` field.

  The extra field carries the BatchNorm running statistics alongside the
  parameters and optimizer state.
  """
  batch_stats: Any

# Initial training state: resnet20's apply function, the parameters and batch
# statistics from the cells above, and an Adam optimizer with learning rate 1e-3.
state = TrainState.create(
  apply_fn=resnet20.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.adam(1e-3),
)

In [20]:
from tqdm import tqdm


@jax.jit
def train_step(state: TrainState, batch):
  """Run one optimisation step on a single batch.

  Args:
    state: current TrainState (parameters, optimizer state, batch statistics).
    batch: mapping with an 'image' entry (model input) and a 'label' entry
      (integer class labels).

  Returns:
    A ``(new_state, metrics)`` pair, where ``metrics`` holds the mean
    cross-entropy 'loss' and the 'accuracy' of this batch.
  """
  images = batch['image']
  labels = batch['label']

  def loss_fn(params):
    # Training-mode forward pass; the refreshed BatchNorm statistics are
    # returned in `updates` because 'batch_stats' is marked mutable.
    model_vars = {'params': params, 'batch_stats': state.batch_stats}
    logits, updates = state.apply_fn(
      model_vars, x=images, on_train=True, mutable=['batch_stats'])
    per_example = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels)
    return per_example.mean(), (logits, updates)

  value_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, updates)), grads = value_and_grad(state.params)

  # Apply the gradient step first, then store the new running statistics.
  new_state = state.apply_gradients(grads=grads)
  new_state = new_state.replace(batch_stats=updates['batch_stats'])

  correct = jnp.argmax(logits, -1) == labels
  return new_state, {'loss': loss, 'accuracy': jnp.mean(correct)}

# Train for 10 passes over `train_ds`, rebinding the module-level `state` and
# `metrics` after every step; `metrics` ends up holding the last step's values.
for i in tqdm(range(10)):
  for batch in train_ds.as_numpy_iterator():
    state, metrics = train_step(state, batch)

100%|██████████| 10/10 [00:25<00:00,  2.55s/it]


In [28]:
def train_custom(lr, epochs=10):
    """Train from the initial weights with Adam at learning rate ``lr``.

    Every call builds a fresh TrainState from the module-level ``params`` and
    ``batch_stats``, so runs with different learning rates start from the same
    weights.

    Args:
        lr: learning rate passed to ``optax.adam``.
        epochs: number of passes over ``train_ds``.

    Returns:
        ``(state, metrics)``: the final TrainState and the metrics of the last
        training step, or ``None`` for the metrics when no step was executed
        (``epochs == 0`` or an empty dataset).
    """
    state = TrainState.create(
        apply_fn=resnet20.apply,
        params=params,
        batch_stats=batch_stats,
        tx=optax.adam(lr),
    )

    # Bound up front so the return below cannot hit an unbound local when the
    # loops never execute a step.
    metrics = None
    for _ in tqdm(range(epochs), leave=False, desc='epochs'):
        for batch in train_ds.as_numpy_iterator():
            state, metrics = train_step(state, batch)

    return state, metrics


# 100 evenly spaced learning rates from 1e-6 to 1.
lrs = jnp.linspace(1e-6, 1, 100)

# Train one model per learning rate and keep each final state and last-step
# metrics.
# NOTE(review): this rebinds the module-level `state` and `metrics` created by
# the earlier training cells. The saved output below shows the run being
# interrupted after the first learning rate (about 50 s per iteration).
states, metrics = [], []
for lr in tqdm(lrs, desc='px'):
    state, metric = train_custom(lr, epochs=10)
    states.append(state)
    metrics.append(metric)



px:   1%|          | 1/100 [00:50<1:22:52, 50.23s/it]


KeyboardInterrupt: 

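* jax+flax: hand-written forward pass and plain gradient-descent update in JAX, operating on the Flax-initialised `variables`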

In [40]:
def _conv_same(x, kernel):
    """Stride-1, 'SAME'-padded 2-D convolution of an NCHW tensor.

    ``kernel`` is taken in the Flax layout (height, width, in, out) and is
    transposed to the (out, in, height, width) layout `jax.lax.conv` expects.
    """
    return jax.lax.conv(x, jnp.transpose(kernel, (3, 2, 0, 1)), window_strides=(1, 1), padding='SAME')


def _pool_nchw(pool_fn, x):
    """Apply a Flax pooling function (3x3 window, stride 1, 'SAME') to an NCHW tensor."""
    # Flax pooling treats the last axis as features, so pool in NHWC and
    # convert back; pooling NCHW directly would slide over channels and height.
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = pool_fn(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    return jnp.transpose(x, [0, 3, 1, 2])


def net(variables, x: jnp.array, on_train=True):
    """Hand-written forward pass driven by Flax-style ``variables``.

    Args:
        variables: mapping with 'params' and 'batch_stats' collections. Reads
            'Conv_0', 'BatchNorm_0', 'Dense_0' and every entry whose name
            contains 'ResNetBlock'.
        x: input batch in NHWC layout (it is transposed to NCHW internally).
        on_train: forwarded to ``batchnorm`` to select batch vs. running
            statistics.

    Returns:
        ``(probabilities, variables)``: the softmax of the final layer and a
        new variables mapping with updated 'batch_stats'. The ``variables``
        passed in are not modified.

    NOTE(review): every convolution uses stride 1, the 'Dense_0' bias (if any)
    is ignored, and the classifier is applied to the flattened feature map --
    confirm these match model.resnet_v3.ResNet.
    """
    params = variables['params']
    # Copy each level that gets written so the caller's mapping stays untouched
    # (this also works when the collections are immutable mappings).
    batch_stats = dict(variables['batch_stats'])

    # NHWC -> NCHW
    x = jnp.transpose(x, [0, 3, 1, 2])

    # Stem: conv - batchnorm - relu - max pool
    x = _conv_same(x, params['Conv_0']['kernel'])
    x, batch_stats['BatchNorm_0'] = batchnorm(
        x, params['BatchNorm_0'], dict(batch_stats['BatchNorm_0']), on_train=on_train)
    x = nn.relu(x)
    x = _pool_nchw(nn.max_pool, x)

    # Residual blocks: conv0 - bn0 - relu - conv1 - bn1 (+ skip) - relu
    for k, v in params.items():
        if 'ResNetBlock' not in k:
            continue

        block_stats = dict(batch_stats[k])
        residual = x

        x = _conv_same(x, v['Conv_0']['kernel'])
        x, block_stats['BatchNorm_0'] = batchnorm(
            x, v['BatchNorm_0'], dict(block_stats['BatchNorm_0']), on_train=on_train)
        x = nn.relu(x)

        x = _conv_same(x, v['Conv_1']['kernel'])
        x, block_stats['BatchNorm_1'] = batchnorm(
            x, v['BatchNorm_1'], dict(block_stats['BatchNorm_1']), on_train=on_train)

        # Optional projection on the skip path when the block provides one.
        if 'Conv_2' in v:
            residual = _conv_same(residual, v['Conv_2']['kernel'])
        x = x + residual
        x = nn.relu(x)

        batch_stats[k] = block_stats

    # Head: back to NHWC, average pool, flatten, linear layer
    x = jnp.transpose(x, [0, 2, 3, 1])
    x = nn.avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding='SAME')
    x = x.reshape((x.shape[0], -1))
    x = jnp.dot(x, params['Dense_0']['kernel'])

    variables = {'params': params, 'batch_stats': batch_stats}

    return nn.softmax(x), variables

def batchnorm(x, params_bn, batch_stats_bn, momentum=0.9, eps=1e-6, on_train=True):
    '''Batch-normalise an NCHW tensor per channel.

        *Args
            x: input of shape (N, C, H, W).
            params_bn: mapping with per-channel 'scale' and 'bias'.
            batch_stats_bn: mapping with per-channel running 'mean' and 'var'.
            momentum: weight of the previous running statistic in the update.
            eps: constant added to the variance before the square root.
            on_train: if true, normalise with the batch statistics and update
                the running statistics; otherwise normalise with the running
                statistics and leave them unchanged.

        *Returns
            (normalised x, new stats mapping). The stats keep the shapes they
            came in with; `batch_stats_bn` itself is not modified.
    '''
    # Everything per-channel is viewed as (1, C, 1, 1) so it broadcasts along
    # the channel axis of NCHW data. A flat (C,) array would broadcast against
    # the trailing width axis instead.
    channel_view = (1, -1, 1, 1)
    gamma = params_bn['scale'].reshape(channel_view)
    beta = params_bn['bias'].reshape(channel_view)

    mean_shape = batch_stats_bn['mean'].shape
    var_shape = batch_stats_bn['var'].shape
    running_mu = batch_stats_bn['mean'].reshape(channel_view)
    running_var = batch_stats_bn['var'].reshape(channel_view)

    def mode_train():
        mu = jnp.mean(x, axis=(0, 2, 3), keepdims=True)
        var = jnp.var(x, axis=(0, 2, 3), keepdims=True)
        r_mu = momentum * running_mu + (1 - momentum) * mu
        r_var = momentum * running_var + (1 - momentum) * var
        return (x - mu) / jnp.sqrt(var + eps), r_mu, r_var

    def mode_inference():
        return (x - running_mu) / jnp.sqrt(running_var + eps), running_mu, running_var

    # Both branches return arrays of identical shapes, as lax.cond requires.
    x, running_mu, running_var = jax.lax.cond(on_train, mode_train, mode_inference)

    x = gamma * x + beta

    # Build a new mapping instead of writing into the argument, restoring the
    # original storage shapes of the statistics.
    new_stats = {
        **batch_stats_bn,
        'mean': running_mu.reshape(mean_shape),
        'var': running_var.reshape(var_shape),
    }

    return x, new_stats


@partial(jax.jit, static_argnums=3)
def loss_fn(variables, x, y, on_train=True):
    """Mean cross-entropy of ``net`` on one batch.

    Args:
        variables: model variables passed through to ``net``.
        x: input batch passed through to ``net``.
        y: integer class labels, one per example.
        on_train: static flag forwarded to ``net``.

    Returns:
        ``(loss, (probs, variables))`` where ``probs`` is the output of ``net``
        and ``variables`` is the variables mapping it returned.
    """
    probs, variables = net(variables, x, on_train=on_train)
    # `net` already ends in a softmax, so the cross-entropy is the negative
    # log of the probability assigned to the true class. Feeding these
    # probabilities into a softmax-based loss would apply softmax twice.
    # The clip keeps log() finite.
    log_probs = jnp.log(jnp.clip(probs, 1e-10, 1.))
    labels = jnp.asarray(y)
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    return -picked.mean(), (probs, variables)

@partial(jax.vmap, in_axes=(0, None, None, 0))
def update_fn(variables, x, y, lr):
    """One gradient-descent step, vectorised over model replicas.

    Through ``jax.vmap`` the leading axis of ``variables`` and of ``lr`` indexes
    independent replicas (one learning rate each); ``x`` and ``y`` are shared
    by all replicas.

    Args:
        variables: per-replica model variables ('params' and 'batch_stats').
        x: input batch.
        y: labels for ``x``.
        lr: per-replica learning rate.

    Returns:
        ``(new_variables, (loss, logits))`` with the updated parameters and the
        variables returned by ``loss_fn``; ``logits`` is whatever ``loss_fn``
        returns as model output.
    """
    (loss, (logits, new_variables)), grads = jax.value_and_grad(loss_fn, has_aux=True)(variables, x, y)
    # `lr` is a scalar inside the vmapped function, so it is closed over rather
    # than passed as a tree: tree_map requires all its trees to share one structure.
    new_params = jax.tree_util.tree_map(
        lambda param, g: param - lr * g, new_variables['params'], grads['params'])
    # Return a new mapping instead of writing into the one loss_fn handed back.
    return {**new_variables, 'params': new_params}, (loss, logits)

def train_custom3(lr, epochs=10):
    """Train from the initial weights for ``epochs`` passes over ``train_ds``.

    The training state is created here from the module-level ``params`` and
    ``batch_stats``; previously the loop read ``state`` before it was ever
    assigned, and ``lr`` was unused.

    NOTE(review): this still drives the Flax ``train_step``; presumably the
    goal is to switch to the vmapped ``update_fn`` defined above -- confirm the
    intended design.

    Args:
        lr: learning rate passed to ``optax.adam``.
        epochs: number of passes over ``train_ds``.

    Returns:
        ``(state, metrics)``: the final TrainState and the metrics of the last
        step, or ``None`` for the metrics when no step was executed.
    """
    state = TrainState.create(
        apply_fn=resnet20.apply,
        params=params,
        batch_stats=batch_stats,
        tx=optax.adam(lr),
    )

    # Bound up front so the return is valid even if the loops run no step.
    metrics = None
    for _ in tqdm(range(epochs), leave=False, desc='epochs'):
        for batch in train_ds.as_numpy_iterator():
            state, metrics = train_step(state, batch)

    return state, metrics
