In [1]:
# From Neel's code
import os
from pathlib import Path
from typing import Callable
import json
import functools
import math


import jax
import jax.numpy as jnp
from jax import random, nn

from optax import adam, rmsprop, sgd

import haiku as hk
from haiku.initializers import Initializer, Constant, RandomNormal, TruncatedNormal, VarianceScaling


def ctc_net_fn(x: jnp.ndarray,
               n_classes: int,
               n_conv_layers: int = 3,
               kernel_size: tuple = (3, 3),
               n_filters: int = 32,
               n_fc_layers: int = 3,
               fc_width: int = 128,
               activation: Callable = nn.relu,
               w_init: Initializer = TruncatedNormal()) -> jnp.ndarray:  # TODO: Batchnorm?
    """Apply a small conv + fully-connected classifier to ``x``.

    The network is ``n_conv_layers`` Conv2D layers ("SAME" padding), each
    followed by ``activation``; a flatten; ``n_fc_layers - 1`` hidden Linear
    layers of width ``fc_width``, each followed by ``activation``; and a
    final Linear layer with ``n_classes`` outputs (no activation).

    Must be called inside a Haiku transform, since it instantiates modules.

    Args:
        x: Input array fed to the first convolution.
        n_classes: Output size of the final Linear layer.
        n_conv_layers: Number of Conv2D layers.
        kernel_size: Kernel shape passed to every Conv2D.
        n_filters: Output channels of every Conv2D.
        n_fc_layers: Total number of Linear layers, including the output layer.
        fc_width: Width of each hidden Linear layer.
        activation: Nonlinearity inserted after every layer except the last.
        w_init: Weight initializer shared by all Conv2D and Linear layers.

    Returns:
        The output of the final Linear layer.
    """
    layers = []

    # Convolutional trunk: conv -> activation, repeated.
    for _ in range(n_conv_layers):
        layers += [
            hk.Conv2D(output_channels=n_filters, kernel_shape=kernel_size,
                      padding="SAME", w_init=w_init),
            activation,
        ]

    layers.append(hk.Flatten())

    # Hidden fully-connected layers; the output layer is created last so the
    # auto-generated Linear module names keep their order.
    for _ in range(n_fc_layers - 1):
        layers += [hk.Linear(fc_width, w_init=w_init), activation]
    layers.append(hk.Linear(n_classes, w_init=w_init))

    return hk.Sequential(layers)(x)


# Root PRNG key for the notebook; later cells split subkeys off it.
key = jax.random.PRNGKey(4)

In [2]:
import dataclasses
from meta_transformer import utils
from jax import vmap
from meta_transformer.transformer import Transformer
from meta_transformer.meta_model import MetaModelClassifier

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Wrap the CNN as a pure (init, apply) pair whose apply takes no rng key.
net = hk.without_apply_rng(hk.transform(ctc_net_fn))
key, subkey = random.split(key)
# Initialise with an all-ones (1, 32, 32, 1) example input and 10 output classes.
params = net.init(subkey, jnp.ones((1, 32, 32, 1)), n_classes=10)
# Show the shape of every parameter array.
jax.tree_util.tree_map(jnp.shape, params)

{'conv2_d': {'b': (32,), 'w': (3, 3, 1, 32)},
 'conv2_d_1': {'b': (32,), 'w': (3, 3, 32, 32)},
 'conv2_d_2': {'b': (32,), 'w': (3, 3, 32, 32)},
 'linear': {'b': (128,), 'w': (32768, 128)},
 'linear_1': {'b': (128,), 'w': (128, 128)},
 'linear_2': {'b': (10,), 'w': (128, 10)}}

In [11]:
# Number of elements in every parameter array.
jax.tree_util.tree_map(jnp.size, params)

{'conv2_d': {'b': 32, 'w': 288},
 'conv2_d_1': {'b': 32, 'w': 9216},
 'conv2_d_2': {'b': 32, 'w': 9216},
 'linear': {'b': 128, 'w': 4194304},
 'linear_1': {'b': 128, 'w': 16384},
 'linear_2': {'b': 10, 'w': 1280}}

In [6]:
# Total number of parameters in the CNN, expressed in millions.
utils.count_params(params) / 1e6

4.23105

In [7]:
# Stack two copies of the CNN parameters into one pytree; the shapes shown
# below gain a leading axis of length 2.
stacked = utils.tree_stack([params, params])
# Use jax.tree_util.tree_map (as the earlier cells do): the top-level
# jax.tree_map alias is deprecated and absent from newer JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, stacked)

{'conv2_d': {'b': (2, 32), 'w': (2, 3, 3, 1, 32)},
 'conv2_d_1': {'b': (2, 32), 'w': (2, 3, 3, 32, 32)},
 'conv2_d_2': {'b': (2, 32), 'w': (2, 3, 3, 32, 32)},
 'linear': {'b': (2, 128), 'w': (2, 32768, 128)},
 'linear_1': {'b': (2, 128), 'w': (2, 128, 128)},
 'linear_2': {'b': (2, 10), 'w': (2, 128, 10)}}

In [8]:
def model_fn(params: dict):
    """Build the meta-model classifier and apply it to ``params``.

    Constructs a 2-layer, 4-head ``Transformer`` (key size 32, dropout 0.1)
    and wraps it in a ``MetaModelClassifier`` with ``model_size=4*32`` and
    10 output classes. Intended to be wrapped with ``hk.transform``.

    Args:
        params: Pytree of network parameters to feed to the classifier
            (in this notebook, the stacked CNN parameters).

    Returns:
        Whatever ``MetaModelClassifier.__call__`` returns for ``params``.
    """
    # NOTE(review): 4*32 presumably equals num_heads * key_size — confirm.
    transformer = Transformer(num_heads=4, num_layers=2, key_size=32, dropout_rate=0.1)
    classifier = MetaModelClassifier(model_size=4*32, num_classes=10, transformer=transformer)
    return classifier(params)


# Turn model_fn into a pure (init, apply) pair; unlike the CNN above, apply
# keeps its rng argument.
model = hk.transform(model_fn)
    
# Each use of randomness gets a fresh subkey split from the notebook-level key.
key, subkey = random.split(key)
# Initialise the meta-model parameters, using the stacked CNN parameters as
# the example input.
meta_params = jax.jit(model.init)(subkey, stacked)
model_forward = jax.jit(model.apply)
key, subkey = random.split(key)
# Forward pass over the stacked parameter sets. The rng key is presumably
# consumed by the transformer's dropout — TODO confirm.
out = model_forward(meta_params, subkey, stacked)

In [12]:
# Shape of the meta-model output for the two stacked parameter sets.
# NOTE(review): the middle axis is presumably a token/chunk axis — confirm
# whether a classifier output should still carry it.
out.shape

(2, 4211, 10)