In [13]:
# From Neel's code
import functools
import json
import math
import os
from pathlib import Path
from typing import Callable, Optional

import haiku as hk
import jax
import jax.numpy as jnp
from haiku.initializers import Initializer, Constant, RandomNormal, TruncatedNormal, VarianceScaling
from jax import nn, random
from optax import adam, rmsprop, sgd


def ctc_net_fn(x: jnp.ndarray,
               n_classes: int,
               n_conv_layers: int = 3,
               kernel_size: tuple = (3, 3),
               n_filters: int = 32,
               n_fc_layers: int = 3,
               fc_width: int = 128,
               activation: Callable = nn.relu,
               w_init: Initializer = TruncatedNormal()) -> jnp.ndarray:  # TODO: Batchnorm?
    """Simple CNN classifier: conv stack -> flatten -> MLP -> logits.

    Must be called inside `hk.transform`.

    Args:
        x: input batch with a leading batch axis; channels-last images
            (e.g. [B, 32, 32, 1] as used below).
        n_classes: width of the final (logit) layer.
        n_conv_layers: number of SAME-padded conv layers, each followed by
            `activation`.
        kernel_size: conv kernel shape.
        n_filters: output channels of every conv layer.
        n_fc_layers: number of dense layers INCLUDING the output layer;
            `n_fc_layers - 1` hidden layers of width `fc_width` are used, so
            values <= 1 still yield the single output layer.
        fc_width: width of the hidden dense layers.
        activation: nonlinearity after each conv and each hidden dense layer.
        w_init: weight initializer shared by every conv and dense layer.
            NOTE(review): `TruncatedNormal()` is not fan-in scaled, and the
            first dense layer is 32768 wide for 32x32 inputs — confirm the
            resulting activation scale is intended.

    Returns:
        Logits of shape [B, n_classes].
    """
    layers = []
    for _ in range(n_conv_layers):
        layers += [
            hk.Conv2D(output_channels=n_filters, kernel_shape=kernel_size,
                      padding="SAME", w_init=w_init),
            activation,
        ]
    layers.append(hk.Flatten())
    # Hidden dense layers are built before the output layer, so Haiku names
    # them linear, linear_1, ... and the output layer gets the last index.
    for _ in range(n_fc_layers - 1):
        layers += [hk.Linear(fc_width, w_init=w_init), activation]
    layers.append(hk.Linear(n_classes, w_init=w_init))
    return hk.Sequential(layers)(x)


key = jax.random.PRNGKey(4)  # fixed seed: every later RNG key is split from this one

In [68]:
import dataclasses
from meta_transformer import utils
from jax import vmap
from meta_transformer.transformer import Transformer

In [15]:
# The CNN draws no randomness in its forward pass, so apply() needs no RNG key.
net = hk.without_apply_rng(
    hk.transform(ctc_net_fn)
)

In [65]:
key, subkey = random.split(key)
# Initialise from a dummy batch of one 32x32 single-channel image, 10 classes.
params = net.init(subkey, jnp.ones((1, 32, 32, 1)), n_classes=10)
jax.tree_util.tree_map(jnp.shape, params)

{'conv2_d': {'b': (32,), 'w': (3, 3, 1, 32)},
 'conv2_d_1': {'b': (32,), 'w': (3, 3, 32, 32)},
 'conv2_d_2': {'b': (32,), 'w': (3, 3, 32, 32)},
 'linear': {'b': (128,), 'w': (32768, 128)},
 'linear_1': {'b': (128,), 'w': (128, 128)},
 'linear_2': {'b': (10,), 'w': (128, 10)}}

In [42]:
jax.tree_util.tree_map(jnp.size, params)  # element count per parameter array

{'conv2_d': {'b': 32, 'w': 288},
 'conv2_d_1': {'b': 32, 'w': 9216},
 'conv2_d_2': {'b': 32, 'w': 9216},
 'linear': {'b': 128, 'w': 4194304},
 'linear_1': {'b': 128, 'w': 16384},
 'linear_2': {'b': 10, 'w': 1280}}

In [50]:
# Total number of CNN parameters, in millions.
utils.count_params(params) / 1e6

4.23105

In [12]:
def chunk_weights(weights: jnp.ndarray, chunk_size: int) -> jnp.ndarray:
    """Flatten `weights`, pad it, and split it into rows of `chunk_size`.

    Args:
        weights: array of any shape (a single, unbatched parameter tensor).
        chunk_size: number of scalars per chunk.

    Returns:
        Array of shape [num_chunks, chunk_size].

    Raises:
        TypeError: if the padded length is not a multiple of `chunk_size`.
            The previous `jnp.split(x, len(x) // chunk_size)` form could
            silently produce chunks of the wrong width in that case.
    """
    flat_weights = utils.pad_to_chunk_size(weights.flatten(), chunk_size)
    # Assumes pad_to_chunk_size pads to a multiple of chunk_size — a single
    # reshape is equivalent to split+stack and fails loudly otherwise.
    return flat_weights.reshape(-1, chunk_size)

In [52]:
@dataclasses.dataclass
class WeightEmbedding(hk.Module):
    """Embeds a batch of weight arrays as sequences of linearly projected chunks.

    Each example is flattened, padded and cut into `chunk_size`-long chunks
    by `chunk_weights`. Every chunk is then mapped to an `embed_dim` vector
    by one shared `hk.Linear`.
    """
    chunk_size: int
    embed_dim: int
    name: Optional[str] = None

    def __call__(
        self,
        weights: jax.Array,
    ) -> jax.Array:  # [B, T, D]
        """Embed `weights`.

        Args:
            weights: array with a leading batch axis B; all remaining axes are
                flattened per example.

        Returns:
            Array of shape [B, T, embed_dim], where T is the number of chunks
            per example.
        """
        embed = hk.Linear(self.embed_dim)
        # chunk_weights touches no Haiku params/state/RNG, so plain jax.vmap is
        # enough (hk.vmap additionally requires `split_rng` in recent Haiku).
        # chunk_size is left unmapped and stays a Python int inside the map.
        weight_chunks = jax.vmap(chunk_weights, in_axes=(0, None))(
            weights, self.chunk_size)  # [B, T, chunk_size]
        return embed(weight_chunks)  # [B, T, embed_dim]


@dataclasses.dataclass
class NetEmbedding(hk.Module):
    """Creates one token sequence [B, T, D] from a batched neural-network params tree.

    Every parameter array is chunked and embedded by a `WeightEmbedding`.
    One embedder is used for each parameter kind: biases, conv kernels, and
    all other weights. The resulting token sequences are concatenated along
    axis 1 in the params' iteration order.
    """
    embed_dim: int
    conv_chunk_size: int = 256
    linear_chunk_size: int = 1024
    bias_chunk_size: int = 16
    name: Optional[str] = None

    def __call__(
            self,
            params: dict,
    ) -> jax.Array:
        """Embed `params`.

        Args:
            params: two-level mapping {module_name: {param_name: array}}
                (e.g. a Haiku params tree); every array has a leading batch
                axis B shared across all leaves.

        Returns:
            Array of shape [B, T, embed_dim], where T is the total number of
            chunks across all parameter arrays.

        Raises:
            ValueError: if `params` contains no arrays.
        """
        # Creation order determines the Haiku module names of the three
        # embedders; keep it stable so existing meta-params still match.
        conv_embed = WeightEmbedding(chunk_size=self.conv_chunk_size, embed_dim=self.embed_dim)
        linear_embed = WeightEmbedding(chunk_size=self.linear_chunk_size, embed_dim=self.embed_dim)
        bias_embed = WeightEmbedding(chunk_size=self.bias_chunk_size, embed_dim=self.embed_dim)

        embeddings = []
        for module_name, module_params in params.items():
            for param_name, value in module_params.items():
                # Compare the parameter name exactly: a suffix test on the
                # combined "module/param" key would treat any name ending in
                # "b" as a bias.
                if param_name == 'b':
                    embedder = bias_embed
                elif 'conv' in module_name:
                    embedder = conv_embed
                else:
                    embedder = linear_embed
                embeddings.append(embedder(value))

        if not embeddings:
            raise ValueError("NetEmbedding received a params tree with no arrays.")
        return jnp.concatenate(embeddings, axis=1)

In [53]:
def tree_stack(trees):
    """Stack identically-structured pytrees leaf-wise along a new leading axis.

    Args:
        trees: non-empty iterable of pytrees with the same structure and
            matching leaf shapes.

    Returns:
        A single pytree whose leaves have an extra leading dimension of size
        len(trees).

    Raises:
        ValueError: if `trees` is empty.
    """
    trees = list(trees)
    if not trees:
        raise ValueError("tree_stack requires at least one tree.")
    # jax.tree_map is deprecated/removed in recent JAX; tree_util is stable.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)

In [66]:
# Two copies of the same CNN params, stacked into a batch of size 2.
stacked = tree_stack([params, params])
# jax.tree_map is deprecated/removed in recent JAX; use jax.tree_util.tree_map.
jax.tree_util.tree_map(lambda x: x.shape, stacked)

{'conv2_d': {'b': (2, 32), 'w': (2, 3, 3, 1, 32)},
 'conv2_d_1': {'b': (2, 32), 'w': (2, 3, 3, 32, 32)},
 'conv2_d_2': {'b': (2, 32), 'w': (2, 3, 3, 32, 32)},
 'linear': {'b': (2, 128), 'w': (2, 32768, 128)},
 'linear_1': {'b': (2, 128), 'w': (2, 128, 128)},
 'linear_2': {'b': (2, 10), 'w': (2, 128, 10)}}

In [55]:
def test_fn(params: dict):
    """Embed a batched params tree with a 32-dim NetEmbedding (run under hk.transform)."""
    return NetEmbedding(embed_dim=32)(params)


# NetEmbedding only uses linear layers, so apply() needs no RNG key.
test = hk.without_apply_rng(hk.transform(test_fn))
key, subkey = random.split(key)
meta_params = test.init(subkey, stacked)

In [56]:
# Number of NetEmbedding (meta-model) parameters, in millions.
utils.count_params(meta_params) / 1e6

0.041568

In [60]:
# JIT-compiled embedder: NetEmbedding's apply with `meta_params` bound as the
# first argument. functools.partial captures the current `meta_params` object,
# so rebinding that name later does not change what `embed` uses.
embed = jax.jit(functools.partial(test.apply, meta_params))

In [67]:
param_embeds = embed(stacked)  # one token sequence per stacked network
jnp.shape(param_embeds)  # (num_networks, num_tokens, embed_dim)

(2, 4211, 32)

In [69]:
from typing import Optional

In [76]:
@dataclasses.dataclass
class Classifier(hk.Module):
    """Transformer over the token sequence produced by `NetEmbedding`.

    No pooling is applied: the model emits one logit vector per token.
    """

    transformer: Transformer
    model_size: int
    num_classes: int
    name: Optional[str] = None
    # NOTE(review): `chunk_size` is never read in this class (chunking is
    # configured inside NetEmbedding); kept for interface compatibility.
    chunk_size: Optional[int] = 4

    def __call__(
        self,
        params: dict,
        *,
        is_training: bool = True,
    ) -> jax.Array:
        """Run the meta-model on a batched params tree.

        Args:
            params: batched params tree in the form accepted by NetEmbedding.
            is_training: forwarded unchanged to `self.transformer` (presumably
                toggles dropout — depends on Transformer).

        Returns:
            Per-token logits of shape [B, T, num_classes].
        """
        tokens = NetEmbedding(embed_dim=self.model_size)(params)  # [B, T, D]
        _, num_tokens, _ = tokens.shape

        # Learned positional embedding, zero-initialised, broadcast over batch.
        pos_embed = hk.get_parameter(
            'positional_embeddings', [num_tokens, self.model_size], init=jnp.zeros)

        hidden = self.transformer(
            tokens + pos_embed,
            is_training=is_training,
        )  # [B, T, D]

        logits_head = hk.Linear(self.num_classes)
        return logits_head(hidden)  # [B, T, num_classes]

In [79]:
def model_fn(params: dict):
    """Build the meta-classifier and apply it to a batched params tree (run under hk.transform)."""
    num_heads, key_size = 4, 32
    # Constructed before the Classifier, exactly as when passed inline.
    transformer = Transformer(
        num_heads=num_heads,
        num_layers=2,
        key_size=key_size,
        dropout_rate=0.1,
    )
    classifier = Classifier(
        model_size=num_heads * key_size,  # 4 * 32 = 128
        num_classes=10,
        transformer=transformer,
    )
    return classifier(params)


model = hk.transform(model_fn)

key, subkey = random.split(key)
meta_params = jax.jit(model.init)(subkey, stacked)
model_forward = jax.jit(model.apply)

# apply() takes an RNG key: is_training defaults to True and dropout_rate > 0.
key, subkey = random.split(key)
out = model_forward(meta_params, subkey, stacked)

In [80]:
jnp.shape(out)  # (num_networks, num_tokens, num_classes)

(2, 4211, 10)