# Setup

In [None]:
# Install the third-party packages this notebook imports below.
# NOTE(review): none of these installs pin a version, and dm-haiku is built from
# the repository head, so a re-run may pick up different releases.
!pip install git+https://github.com/deepmind/dm-haiku
!pip install transformers
!pip install clu
!pip install wandb
!pip install optax
!pip install flatten-dict

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-a1qsj3fk
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-a1qsj3fk
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Building wheels for collected packages: dm-haiku
  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
  Created wheel for dm-haiku: filename=dm_haiku-0.0.8.dev0-py3-none-any.whl size=607970 sha256=9978cd15953d982a8b2a76d01741ac16cdecdf7fbae281c62c6d787513b35ce3
  Stored in directory: /tmp/pip-ephem-wheel-cache-pn1w3wo5/wheels/06/28/69/ebaac5b2435641427299f29d88d005fb4e2627f4a108f0bdbc
Successfully built dm-haiku
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.8.dev0 jmp-0.0.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-whee

In [None]:
# sys
import os, sys
from io import BytesIO
# helper
import re
import joblib
import requests
from functools import lru_cache
import flatten_dict
#ML CORE
import jax
from jax.random import PRNGKey
from jax import jit
import jax.numpy as jnp
import numpy as np
import optax
import haiku as hk
import torch
import tensorflow as tf
import tensorflow_datasets as tfds
# ML helper
from clu import parameter_overview
from transformers import RobertaTokenizer, RobertaModel,FlaxRobertaModel
# logging
import wandb

# Modules

## Embedding

In [None]:
class Embedding(hk.Module):
    """
    Embeds tokens and positions into an array of shape [n_batch, n_seq, n_hidden].

    Every parameter created here (word embeddings, position embeddings and the
    LayerNorm scale/offset) is initialised from the arrays in config['pretrained'].
    """
    def __init__(self, config,name=None):
        super().__init__(name=name)
        self.config = config

    def __call__(self, token_ids, training=False):
        """
        token_ids: ints of shape (batch, n_seq), with n_seq <= config['max_length']
        training: when True, dropout with rate config['embed_dropout_rate'] is applied

        Returns the layer-normalised embeddings of shape (batch, n_seq, n_hidden).
        """
        word_embeddings = self.config['pretrained']['embeddings/word_embeddings/embedding']
        n_batch, n_seq = token_ids.shape[0], token_ids.shape[1]
        n_hidden = word_embeddings.shape[1]

        # The lookup is done on a flat vector of ids; the batch dimension is
        # restored right afterwards.
        flat_token_ids = jnp.reshape(token_ids, [n_batch * n_seq])
        flat_token_embeddings = hk.Embed(
            vocab_size=word_embeddings.shape[0],
            embed_dim=n_hidden,

            # hk.initializers.Constant supplies the pre-trained embeddings
            # to the hk.Embed module
            w_init=hk.initializers.Constant(word_embeddings)
        )(flat_token_ids)
        token_embeddings = jnp.reshape(flat_token_embeddings, [n_batch, n_seq, n_hidden])

        # PositionEmbeddings always returns config['max_length'] rows. Keep only the
        # first n_seq of them so that batches padded to less than max_length still
        # broadcast against the token embeddings (for n_seq == max_length this is a no-op).
        position_embeddings = PositionEmbeddings(self.config)()[:n_seq]
        embeddings = token_embeddings + position_embeddings

        embeddings = hk.LayerNorm(
            axis=-1,
            create_scale=True,
            create_offset=True,

            # The layer norm parameters are also pretrained, so they use
            # constant initializers as well
            scale_init=hk.initializers.Constant(
                self.config['pretrained']['embeddings/LayerNorm/scale']
            ),
            offset_init=hk.initializers.Constant(
                self.config['pretrained']['embeddings/LayerNorm/bias']
            )
        )(embeddings)

        # Dropout is only active while finetuning; inference callers leave
        # `training` at its default of False.
        if training:
            embeddings = hk.dropout(
                # The RNG key comes from the key handed to the transformed function
                hk.next_rng_key(),
                rate=self.config['embed_dropout_rate'],
                x=embeddings
            )

        return embeddings

## Positional embedding

In [None]:
class PositionEmbeddings(hk.Module):
    """
    Learned position embeddings, returned as an array of shape
    [config['max_length'], n_hidden] and initialised from the pretrained table.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        # The first two rows of the Roberta position table are skipped
        self.offset = 2

    def __call__(self):
        pretrained_table = self.config['pretrained']['embeddings/position_embeddings/embedding']
        table = hk.get_parameter(
            "position_embeddings",
            pretrained_table.shape,
            init=hk.initializers.Constant(pretrained_table)
        )

        # Take `max_length` consecutive rows starting right after the offset
        start = self.offset
        stop = start + self.config['max_length']
        return table[start:stop]

## Multi Head Attention

In [None]:
class MultiHeadAttention(hk.Module):
    """
    Multi-head self-attention whose four linear projections are initialised from
    the pretrained weights stored under 'encoder/layer/{layer_num}/attention/'.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.n = layer_num

    def _split_into_heads(self, x):
        """Reshape (batch, seq, hidden) into (batch, seq, n_heads, hidden // n_heads)."""
        n_heads = self.config['n_heads']
        head_shape = [x.shape[0], x.shape[1], n_heads, x.shape[2] // n_heads]
        return jnp.reshape(x, head_shape)

    def __call__(self, x, mask, training=False):
        """
        x: tensor of shape (batch, seq, n_hidden)
        mask: tensor of shape (batch, seq); non-zero entries mark positions that
            must not be attended to
        training: when True, dropout is applied to the projected attention output
        """
        scope = Scope(self.config['pretrained'], f'encoder/layer/{self.n}/attention/')
        hidden_size = self.config['hidden_size']

        def pretrained_linear(path):
            # Dense layer whose kernel and bias are the pretrained arrays at `path`
            return hk.Linear(
                output_size=hidden_size,
                w_init=hk.initializers.Constant(scope[path + '/kernel']),
                b_init=hk.initializers.Constant(scope[path + '/bias'])
            )

        # Project to queries, keys and values ([batch, seq, hidden]) and group each
        # into heads ([batch, seq, n_heads, size_per_head]). The layers are created
        # in the order query, key, value so the haiku parameter names stay stable.
        queries = self._split_into_heads(pretrained_linear('self/query')(x))
        keys = self._split_into_heads(pretrained_linear('self/key')(x))
        values = self._split_into_heads(pretrained_linear('self/value')(x))

        # einsum index legend:
        # b: batch, s: source sequence, t: target sequence,
        # n: number of heads, h: per-head hidden state
        scaled_scores = jnp.einsum('bsnh,btnh->bnst', queries, keys) / np.sqrt(queries.shape[-1])

        # Masked positions receive a large negative logit so that the softmax
        # assigns them (almost) no weight.
        mask_bias = jnp.reshape(mask * -2.0**32, [mask.shape[0], 1, 1, mask.shape[1]])
        attention_weights = jax.nn.softmax(scaled_scores + mask_bias, axis=-1)

        # Weighted sum of the values, then merge the heads back into one hidden axis
        per_head_output = jnp.einsum('btnh,bnst->bsnh', values, attention_weights)
        merged_shape = [
            per_head_output.shape[0],
            per_head_output.shape[1],
            per_head_output.shape[2] * per_head_output.shape[3]
        ]
        merged_output = jnp.reshape(per_head_output, merged_shape)

        # Output projection applied to the merged heads
        attention_output = pretrained_linear('output/dense')(merged_output)

        # Dropout only at training time
        if training:
            attention_output = hk.dropout(
                rng=hk.next_rng_key(),
                rate=self.config['attention_drop_rate'],
                x=attention_output
            )

        return attention_output

## Transformer MLP

In [None]:
def gelu(x):
    """
    Exact (erf-based) GELU activation: x * 0.5 * (1 + erf(x / sqrt(2))).

    The erf form is spelled out here because, per the original author, an
    approximated activation produces a non-trivial difference in the output state.
    """
    erf_term = jax.scipy.special.erf(x / jnp.sqrt(2.0))
    return x * 0.5 * (1.0 + erf_term)
class TransformerMLP(hk.Module):
    """
    Position-wise feed-forward block of encoder layer `layer_num`: a projection to
    config['intermediate_size'], a gelu, and a projection back to
    config['hidden_size'], all initialised from the pretrained weights.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.n = layer_num

    def __call__(self, x, training=False):
        scope = Scope(self.config['pretrained'], f'encoder/layer/{self.n}/')

        def pretrained_linear(path, output_size):
            # Dense layer whose kernel and bias are the pretrained arrays at `path`
            return hk.Linear(
                output_size=output_size,
                w_init=hk.initializers.Constant(scope[path + '/kernel']),
                b_init=hk.initializers.Constant(scope[path + '/bias'])
            )

        # Expand to the intermediate width and apply the non-linearity
        expanded = pretrained_linear('intermediate/dense', self.config['intermediate_size'])(x)
        activated = gelu(expanded)

        # Contract back down to the hidden width
        output = pretrained_linear('output/dense', self.config['hidden_size'])(activated)

        # Dropout only at training time
        if training:
            output = hk.dropout(
                rng=hk.next_rng_key(),
                rate=self.config['fully_connected_drop_rate'],
                x=output
            )

        return output
        

## Transformer Block

In [None]:
class TransformerBlock(hk.Module):
    """
    One encoder layer: self-attention and a feed-forward block, each followed by a
    residual connection and a LayerNorm initialised from the pretrained weights.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.n = layer_num

    def __call__(self, x, mask, training=False):
        scope = Scope(self.config['pretrained'], f'encoder/layer/{self.n}/')

        def pretrained_layer_norm(path):
            # LayerNorm over the last axis with pretrained scale and offset
            return hk.LayerNorm(
                axis=-1,
                create_scale=True,
                create_offset=True,
                scale_init=hk.initializers.Constant(scope[path + '/scale']),
                offset_init=hk.initializers.Constant(scope[path + '/bias']),
            )

        # Self-attention, residual connection with the layer input, then LayerNorm
        attended = MultiHeadAttention(self.config, self.n)(x, mask, training=training)
        attention_output = pretrained_layer_norm('attention/output/LayerNorm')(attended + x)

        # Feed-forward block, residual connection with the attention output, then LayerNorm
        mlp_output = TransformerMLP(self.config, self.n)(attention_output, training=training)
        return pretrained_layer_norm('output/LayerNorm')(mlp_output + attention_output)

## Transformer & RoBERTa featurizer (same)

In [None]:
class Transformer(hk.Module):
    """
    Stack of `config['n_layers']` TransformerBlocks applied on top of the
    token/position embeddings.
    """
    def __init__(self, config, *args, **kwargs):
        super().__init__(name="Transformer")
        self.config = config

    def __call__(self, token_ids, training=False):
        """
        token_ids: ints of shape (batch, n_seq)
        training: forwarded to the sub-modules to enable dropout

        Returns the output of the last TransformerBlock.
        """
        x = Embedding(self.config)(token_ids, training=training)
        # TransformerBlock requires a mask: 1.0 where the token id equals
        # config['mask_id'], 0.0 elsewhere.
        mask = (token_ids == self.config['mask_id']).astype(jnp.float32)
        # `config` is a plain dict, so it has to be indexed by key
        for layer_num in range(self.config['n_layers']):
            x = TransformerBlock(self.config, layer_num=layer_num)(x, mask, training=training)
        return x


In [None]:
# Same as above
class RobertaFeaturizer(hk.Module):
    """
    Produces contextual embeddings: the token/position embeddings passed through
    `config['n_layers']` TransformerBlocks.
    """
    def __init__(self, config, *args, **kwargs):
        # The haiku module name is "Transformer"; it prefixes every parameter name
        super().__init__(name="Transformer")
        self.config = config

    def __call__(self, token_ids, training=False):
        hidden = Embedding(self.config)(token_ids, training=training)
        # 1.0 where the token id equals config['mask_id'], 0.0 elsewhere
        is_masked = token_ids == self.config['mask_id']
        mask = is_masked.astype(jnp.float32)
        n_layers = self.config['n_layers']
        for layer_num in range(n_layers):
            block = TransformerBlock(self.config, layer_num=layer_num)
            hidden = block(hidden, mask, training=training)
        return hidden


## RoBERTa classifier

In [None]:
class RobertaClassifier(hk.Module):
    """
    Classification head on top of RobertaFeaturizer: the state of the first token
    is fed through a stack of linear layers ending in `config['n_classes']` logits.
    """
    def __init__(self, config, *args, **kwargs):
        super().__init__(name="Transformer")
        self.config = config

    def __call__(self, token_ids, training=False):
        sequence_features = RobertaFeaturizer(self.config)(token_ids=token_ids, training=training)

        # The classifier representation is the output state of the first token
        clf_state = sequence_features[:,0,:]

        if training:
            clf_state = hk.dropout(
                rng=hk.next_rng_key(),
                rate=self.config['classifier_drop_rate'],
                x=clf_state
            )

        def head_layer(index, output_size):
            # Randomly initialised linear layer named 'logits_linear_<index>'
            return hk.Linear(
                output_size=output_size,
                w_init=hk.initializers.TruncatedNormal(self.config['weight_stddev']),
                name=f'logits_linear_{index}'
            )

        # Layer 0 projects the classifier state to width 512
        x = head_layer(0, 512)(clf_state)

        # Layers 1..6 are width-preserving linear layers with a residual connection
        for index in range(1, 7):
            x = head_layer(index, 512)(x) + x

        # Layer 7 projects down to n_classes; its output is used as the softmax logits
        return head_layer(7, self.config['n_classes'])(x)


# Helper functions

## Get pretrained weights

In [None]:
# We'll make use of these again later as a means to check our implementation
# Name of the checkpoint loaded for both the PyTorch and the Flax model below.
model_name = 'distilroberta-base'
# PyTorch reference model; force_download=True bypasses any cached copy.
huggingface_roberta_torch = RobertaModel.from_pretrained(
    model_name, 
    output_hidden_states=True,
    force_download= True,
    )
# Flax model built from the PyTorch checkpoint (from_pt=True); its parameters are
# the source of the `pretrained` weight dictionary used by the haiku modules.
huggingface_roberta = FlaxRobertaModel.from_pretrained(
      model_name,
      output_hidden_states=True,
      force_download= True,
      from_pt=True
    )
# NOTE(review): the tokenizer is loaded from 'roberta-base' rather than
# `model_name` -- presumably the two share a vocabulary; confirm.
huggingface_tokenizer = RobertaTokenizer.from_pretrained(
    'roberta-base', 
    force_download=True
    )

Downloading config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/316M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/316M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilroberta-base were not used when initializing FlaxRobertaModel: {('lm_head', 'dense', 'bias'), ('lm_head', 'bias'), ('lm_head', 'decoder', 'kernel'), ('lm_head', 'layer_norm', 'kernel'), ('lm_head', 'layer_norm', 'bias'), ('lm_head', 'dense', 'kernel')}
- This IS expected if you are initializing FlaxRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading vocab.json:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

In [None]:
# Display the configuration object attached to the PyTorch reference model.
huggingface_roberta_torch.config

RobertaConfig {
  "_name_or_path": "distilroberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "output_hidden_states": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.21.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

## Fetch pretrained weights

In [None]:
def flatten(d):
    """
    Collapse a nested dict into a single-level dict whose keys are the nested
    keys joined with '_'.

    A dict value is flattened recursively; a list value is treated as a list of
    dicts, each flattened under the same parent key. Any other value is a leaf.
    Note that a later cell redefines `flatten` with '/' as the separator.
    """
    flat = {}
    for name, value in d.items():
        # A single dict is handled as a one-element list of dicts
        children = [value] if isinstance(value, dict) else value
        if not isinstance(children, list):
            flat[name] = value
            continue
        for child in children:
            for sub_name, sub_value in flatten(child).items():
                flat[name + '_' + sub_name] = sub_value
    return flat

In [None]:
def flatten(d, sep='/'):
    """
    Collapse a nested mapping into a single-level dict whose keys are the nested
    keys joined with `sep` (default '/').

    d: mapping with string keys; values may be mappings, lists of mappings, or leaves
    sep: separator inserted between a parent key and a child key

    A mapping value (any `collections.abc.Mapping`, not only `dict`) is flattened
    recursively. A list made up solely of mappings is flattened element by element
    under the same parent key. Every other value -- including a list holding
    non-mapping items -- is kept as a leaf.

    Returns a new dict; `d` is not modified.
    """
    from collections.abc import Mapping

    out = {}
    for key, val in d.items():
        if isinstance(val, Mapping):
            children = [val]
        elif isinstance(val, list) and all(isinstance(item, Mapping) for item in val):
            # An empty list yields no entries
            children = val
        else:
            out[key] = val
            continue
        for child in children:
            for sub_key, sub_val in flatten(child, sep).items():
                out[key + sep + sub_key] = sub_val
    return out

In [None]:
# Take the parameter tree of the Flax model and flatten it, so that every weight
# can be looked up by a single '/'-joined key such as
# 'embeddings/word_embeddings/embedding'.
# NOTE(review): `_params` is a private attribute of the model object.
pretrained = huggingface_roberta._params
pretrained = flatten(pretrained)
#keys = [k for k in pretrained.keys() if 'embedding' in k]
#for key in keys:
#  print(key,'\n')


dict_keys(['embeddings/word_embeddings/embedding', 'embeddings/position_embeddings/embedding', 'embeddings/token_type_embeddings/embedding', 'embeddings/LayerNorm/scale', 'embeddings/LayerNorm/bias', 'encoder/layer/0/attention/self/query/kernel', 'encoder/layer/0/attention/self/query/bias', 'encoder/layer/0/attention/self/key/kernel', 'encoder/layer/0/attention/self/key/bias', 'encoder/layer/0/attention/self/value/kernel', 'encoder/layer/0/attention/self/value/bias', 'encoder/layer/0/attention/output/dense/kernel', 'encoder/layer/0/attention/output/dense/bias', 'encoder/layer/0/attention/output/LayerNorm/scale', 'encoder/layer/0/attention/output/LayerNorm/bias', 'encoder/layer/0/intermediate/dense/kernel', 'encoder/layer/0/intermediate/dense/bias', 'encoder/layer/0/output/dense/kernel', 'encoder/layer/0/output/dense/bias', 'encoder/layer/0/output/LayerNorm/scale', 'encoder/layer/0/output/LayerNorm/bias', 'encoder/layer/1/attention/self/query/kernel', 'encoder/layer/1/attention/self/que

In [None]:
# Single configuration dict shared by every module and by the training code.
config = {
    # Flattened pretrained weights, keyed by '/'-joined parameter path
    'pretrained': pretrained,
    # Sequence length that inputs are padded to, and number of position rows used
    'max_length': 512,
    # Dropout rates; only applied when modules are called with training=True
    'embed_dropout_rate': 0.1,
    'fully_connected_drop_rate': 0.1,
    'attention_drop_rate': 0.1,
    # Architecture sizes; these match the RobertaConfig printed above
    'hidden_size': 768,
    'intermediate_size': 3072,
    'n_heads': 12,
    'n_layers': 6,
    # Token id that RobertaFeaturizer excludes from attention; 1 is the
    # pad_token_id in the RobertaConfig printed above
    'mask_id': 1,
    # Stddev of the TruncatedNormal initializer used by the classifier head
    'weight_stddev': 0.02,

    # For use later in finetuning
    'n_classes': 2,
    'classifier_drop_rate': 0.1,
    'learning_rate': 0.00001,
    'max_grad_norm': 1.0,
    'l2': 0.1,
    'n_epochs': 5,
    'batch_size': 32
}

In [None]:
#print(parameter_overview.get_parameter_overview(huggingface_roberta._params))

## Scope function: systematic weight dictionary lookup

In [None]:
class Scope(object):
    """
    Read-only view onto a weight dictionary that prepends a fixed prefix to every
    key, so callers can use short relative names. Plain Python, no haiku involved.
    """
    def __init__(self, weights, prefix):
        self.prefix = prefix
        self.weights = weights

    def __getitem__(self, key):
        # Resolve the relative key against the prefix and delegate to the dict
        return self.weights[self.prefix + key]


# Tests

## Test tokenizer

In [None]:
# Tokenize the same sentence twice (a batch of two), padding and truncating to
# config['max_length'], and print the first 50 token ids of the first example.
sample_text = "This was a lot less painful than re-implementing"
encoded = huggingface_tokenizer.batch_encode_plus(
    [sample_text, sample_text],
    padding='max_length',
    max_length=config['max_length'],
    truncation=True
)
sample_tokens = encoded['input_ids']
print(sample_tokens[0][:50])

[0, 713, 21, 10, 319, 540, 8661, 87, 769, 12, 757, 40224, 154, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


## Construct transformer

In [None]:
# The module-level `config` dict is read when `features` is called
def features(tokens):
    """Build a `Transformer` from the global `config` and apply it to `tokens`."""
    return Transformer(config)(tokens)
features_fn = hk.transform(features)

## Construct and test embedding module

In [None]:
def embed_fn(tokens, training=False):
    """
    Apply the `Embedding` module (built from the global `config`) to `tokens`.

    tokens: int token ids of shape (batch, n_seq)
    training: forwarded to `Embedding` so that dropout is enabled when True
    """
    embedding = Embedding(config)(tokens, training=training)
    return embedding
# Turn embed_fn into a pure (init, apply) pair, initialise its parameters and run
# it on the sample batch.
rng = PRNGKey(42)
embed = hk.transform(embed_fn, apply_rng=True)
sample_tokens = np.asarray(sample_tokens)
params = embed.init(rng, sample_tokens, training=False)
embedded_tokens = jit(embed.apply)(params, rng, sample_tokens)
# Inspect the result: output shape, then the structure of the parameter tree.
print(embedded_tokens.shape)
print({key: type(value) for key, value in params.items()})
print(jax.tree_map(lambda x: x.shape, params))
print({key: type(value) for key, value in params['embedding/layer_norm'].items()})
print(parameter_overview.get_parameter_overview(params))

(2, 512, 768)
{'embedding/embed': <class 'dict'>, 'embedding/position_embeddings': <class 'dict'>, 'embedding/layer_norm': <class 'dict'>}
{'embedding/embed': {'embeddings': (50265, 768)}, 'embedding/layer_norm': {'offset': (768,), 'scale': (768,)}, 'embedding/position_embeddings': {'position_embeddings': (514, 768)}}
{'scale': <class 'jaxlib.xla_extension.DeviceArray'>, 'offset': <class 'jaxlib.xla_extension.DeviceArray'>}
+---------------------------------------------------+--------------+------------+----------+--------+
| Name                                              | Shape        | Size       | Mean     | Std    |
+---------------------------------------------------+--------------+------------+----------+--------+
| embedding/embed/embeddings                        | (50265, 768) | 38,603,520 | -0.0113  | 0.13   |
| embedding/layer_norm/offset                       | (768,)       | 768        | -0.00171 | 0.0723 |
| embedding/layer_norm/scale                        | (768,)  

## Construct and test featurizer 

In [None]:
def featurizer_fn(tokens, training=False):
    """Apply `RobertaFeaturizer` (built from the global `config`) to `tokens`."""
    featurizer = RobertaFeaturizer(config)
    return featurizer(tokens, training=training)
# Transform the featurizer, initialise its parameters from the sample batch and
# compute contextual embeddings for that batch.
rng = PRNGKey(42)
roberta = hk.transform(featurizer_fn, apply_rng=True)
sample_tokens = np.asarray(sample_tokens)
params = roberta.init(rng, sample_tokens, training=False)
contextual_embedding = jit(roberta.apply)(params, rng, sample_tokens)
# Print a table of every parameter with its shape, size, mean and std.
print(parameter_overview.get_parameter_overview(params))
#print(contextual_embedding.shape)


+-----------------------------------------------------------------+--------------+------------+-----------+----------+
| Name                                                            | Shape        | Size       | Mean      | Std      |
+-----------------------------------------------------------------+--------------+------------+-----------+----------+
| Transformer/embedding/embed/embeddings                          | (50265, 768) | 38,603,520 | -0.0113   | 0.13     |
| Transformer/embedding/layer_norm/offset                         | (768,)       | 768        | -0.00171  | 0.0723   |
| Transformer/embedding/layer_norm/scale                          | (768,)       | 768        | 0.349     | 0.0619   |
| Transformer/embedding/position_embeddings/position_embeddings   | (514, 768)   | 394,752    | 0.000403  | 0.0663   |
| Transformer/transformer_block/layer_norm/offset                 | (768,)       | 768        | 0.102     | 0.181    |
| Transformer/transformer_block/layer_norm/scale

In [None]:
# List the names of all pretrained weights that belong to encoder layer 0.
[k for k in pretrained.keys() if 'layer/0/' in k] 

['encoder/layer/0/attention/self/query/kernel',
 'encoder/layer/0/attention/self/query/bias',
 'encoder/layer/0/attention/self/key/kernel',
 'encoder/layer/0/attention/self/key/bias',
 'encoder/layer/0/attention/self/value/kernel',
 'encoder/layer/0/attention/self/value/bias',
 'encoder/layer/0/attention/output/dense/kernel',
 'encoder/layer/0/attention/output/dense/bias',
 'encoder/layer/0/attention/output/LayerNorm/scale',
 'encoder/layer/0/attention/output/LayerNorm/bias',
 'encoder/layer/0/intermediate/dense/kernel',
 'encoder/layer/0/intermediate/dense/bias',
 'encoder/layer/0/output/dense/kernel',
 'encoder/layer/0/output/dense/bias',
 'encoder/layer/0/output/LayerNorm/scale',
 'encoder/layer/0/output/LayerNorm/bias']

## Compare model outputs: Roberta-huggingface-torch vs ours

In [None]:
#import torch
#batch_token_ids = torch.tensor(huggingface_tokenizer.encode(sample_text)).unsqueeze(0)
#huggingface_output_state, huggingface_pooled_state, _ = huggingface_roberta_torch.forward(batch_token_ids)[:]
#print(np.allclose(
#    huggingface_output_state.detach().numpy(), 
#    contextual_embedding[:1, :batch_token_ids.size()[1]], 
#    atol=1e-3
#))

# Data and Training

## Load dataset

In [None]:
def txt_transform(text):
  """
  Normalise a batch of texts: replace every run of non-alphanumeric characters
  with a space, lower-case the result and drop words of three characters or fewer.

  text: iterable of items; `bytes` items are decoded as UTF-8 (undecodable bytes
    are replaced), every other item is converted with `str()`
  Returns a numpy array with one cleaned string per input item.
  """
  output = []
  for t in text:
    if isinstance(t, bytes):
      # str(bytes) would yield the "b'...'" repr and leak escape sequences such as
      # "\xe2" into the text as spurious tokens, so decode instead.
      t = t.decode('utf-8', errors='replace')
    else:
      t = str(t)
    t = re.sub('[^A-Za-z0-9]+', ' ', t).lower()
    words = [w for w in t.split(' ') if len(w) > 3]
    output.append(' '.join(words))
  return np.array(output)
def load_dataset(split, training, batch_size, n_examples=None):
    """
    Loads the "imdb_reviews" dataset as a generator of numpy batches.

    split: name of the TFDS split, e.g. "train"
    training: when True, examples are shuffled (buffer of 10 * batch_size, seed 0)
    batch_size: number of examples per batch
    n_examples: if given, only the first `n_examples` of the split are used;
        None (the default) uses the whole split
    """
    # Only add a slice when a limit was requested: formatting None into the spec
    # would produce "train[:None]", which is not a valid split string.
    split_spec = split if n_examples is None else f"{split}[:{n_examples}]"
    ds = tfds.load(
        "imdb_reviews", split=split_spec
        ).cache()
    if training:
        ds = ds.shuffle(10 * batch_size, seed=0)
    ds = ds.batch(batch_size)
    return tfds.as_numpy(ds)

# Number of training examples to use; also used later to size the learning-rate
# schedule.
n_examples = 25000
# Shuffled, batched training data.
train = load_dataset(
    "train", 
    training=True, 
    batch_size=config['batch_size'], 
    n_examples=n_examples
    )

## Initialise roberta classifier

In [None]:
def roberta_classification_fn(batch_token_ids, training):
    """
    Run `RobertaClassifier` (built from the global `config`) on a batch of token
    ids and return its output, the classification logits.
    """
    classifier = RobertaClassifier(config)
    token_ids = jnp.asarray(batch_token_ids)
    return classifier(token_ids, training=training)


def encode_batch(batch_text):
    """
    Convert a batch of raw texts into a numpy array of token ids.

    batch_text: iterable of `str` or UTF-8 encoded `bytes`
    Returns an int array with one row per text; every row is padded or truncated
    to config['max_length'] tokens.
    """
    # Accept either utf-8 encoded bytes or unicode
    batch_text = [
        text.decode('utf-8') if isinstance(text, bytes) else text
        for text in batch_text
    ]

    # Use huggingface's tokenizer to convert from raw text to integer token ids.
    # Padding and truncation are requested explicitly, matching the tokenizer call
    # in the "Test tokenizer" cell: `pad_to_max_length` is deprecated, and leaving
    # truncation implicit triggers a tokenizer warning.
    token_ids = huggingface_tokenizer.batch_encode_plus(
        batch_text,
        padding='max_length',
        truncation=True,
        max_length=config['max_length'],
    )['input_ids']
    return np.asarray(token_ids)


# Purify our RobertaClassifier through the use of hk.transform and initialize our classifier
rng = jax.random.PRNGKey(42)
roberta_classifier = hk.transform(roberta_classification_fn, apply_rng=True)
# A dummy two-example batch is enough to create every parameter; training=True
# means the dropout branches are traced during init as well.
params = roberta_classifier.init(
    rng, 
    batch_token_ids=encode_batch(['Sample text', 'Sample text']), 
    training=True
)
print(parameter_overview.get_parameter_overview(params))

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


+-----------------------------------------------------------------------------+--------------+------------+-----------+----------+
| Name                                                                        | Shape        | Size       | Mean      | Std      |
+-----------------------------------------------------------------------------+--------------+------------+-----------+----------+
| Transformer/Transformer/embedding/embed/embeddings                          | (50265, 768) | 38,603,520 | -0.0113   | 0.13     |
| Transformer/Transformer/embedding/layer_norm/offset                         | (768,)       | 768        | -0.00171  | 0.0723   |
| Transformer/Transformer/embedding/layer_norm/scale                          | (768,)       | 768        | 0.349     | 0.0619   |
| Transformer/Transformer/embedding/position_embeddings/position_embeddings   | (514, 768)   | 394,752    | 0.000403  | 0.0663   |
| Transformer/Transformer/transformer_block/layer_norm/offset                 | (76

In [None]:
def FineTuningDepth(params,depth):
  """
  Build a boolean mask with the same structure as `params` that marks which
  parameter groups are trainable.

  The top-level keys of `params` are ordered as: keys containing 'embedding',
  then keys containing 'transformer_block', then keys containing 'logits_linear'
  (within each group the order of `params.keys()` is kept). The last `depth`
  keys of that ordering are marked True, everything else False.

  params: dict mapping a module name to a dict of that module's parameters
  depth: number of trailing parameter groups to unfreeze; -1 unfreezes all,
    0 (or any other negative value) unfreezes none
  Returns a new tree of booleans; `params` itself is not modified.
  Raises AssertionError if a key falls into none (or more than one) of the three
  groups, or if `depth` exceeds the number of groups.
  """
  # Use jax.tree_util.tree_map throughout; `jax.tree_map` is a deprecated alias.
  tree_map = jax.tree_util.tree_map
  base_mask = tree_map(lambda x: False, params)
  keys = params.keys()
  first = [k for k in keys if 'embedding' in k]
  second = [k for k in keys if 'transformer_block' in k]
  third = [k for k in keys if 'logits_linear' in k]
  ordered_keys = [*first,*second,*third]
  n_groups = len(ordered_keys)
  assert n_groups == len(keys), (
      "every parameter group must match exactly one of "
      "'embedding', 'transformer_block', 'logits_linear'")
  assert depth <= n_groups, (
      f"depth={depth} exceeds the number of parameter groups ({n_groups})")
  # Walk from the output end of the network backwards, counting groups from 1
  for num, layer in enumerate(reversed(ordered_keys), start=1):
    if num <= depth or depth == -1:
      base_mask[layer] = tree_map(lambda x: True, base_mask[layer])
  return base_mask

# depth=-1 marks every parameter group as trainable; display the resulting mask.
base_mask = FineTuningDepth(params,-1)
base_mask
  

{'Transformer/Transformer/embedding/embed': {'embeddings': True},
 'Transformer/Transformer/embedding/layer_norm': {'offset': True,
  'scale': True},
 'Transformer/Transformer/embedding/position_embeddings': {'position_embeddings': True},
 'Transformer/Transformer/transformer_block/layer_norm': {'offset': True,
  'scale': True},
 'Transformer/Transformer/transformer_block/layer_norm_1': {'offset': True,
  'scale': True},
 'Transformer/Transformer/transformer_block/multi_head_attention/linear': {'b': True,
  'w': True},
 'Transformer/Transformer/transformer_block/multi_head_attention/linear_1': {'b': True,
  'w': True},
 'Transformer/Transformer/transformer_block/multi_head_attention/linear_2': {'b': True,
  'w': True},
 'Transformer/Transformer/transformer_block/multi_head_attention/linear_3': {'b': True,
  'w': True},
 'Transformer/Transformer/transformer_block/transformer_mlp/linear': {'b': True,
  'w': True},
 'Transformer/Transformer/transformer_block/transformer_mlp/linear_1': {'b

## Training functions

In [None]:
def loss(params, rng, batch_token_ids, batch_labels):
    """Softmax cross-entropy of the classifier, averaged over the batch.

    Runs `roberta_classifier` in training mode, one-hot encodes the integer
    labels into `config['n_classes']` classes, sums the negative
    log-likelihood of the true classes and divides by the size of the leading
    (batch) axis of the one-hot labels.

    Args:
      params: classifier parameters.
      rng: random key forwarded to `roberta_classifier.apply`.
      batch_token_ids: token ids for the batch.
      batch_labels: integer class labels for the batch.

    Returns:
      Scalar loss.
    """
    one_hot_labels = hk.one_hot(batch_labels, config['n_classes'])
    log_probs = jax.nn.log_softmax(
        roberta_classifier.apply(params, rng, batch_token_ids, training=True)
    )
    total_xent = -jnp.sum(one_hot_labels * log_probs)
    return total_xent / one_hot_labels.shape[0]

@jax.jit
def accuracy(params, rng, batch_token_ids, batch_labels):
    """Fraction of examples in the batch whose arg-max class equals the label.

    The classifier is applied with `training=False`.

    Args:
      params: classifier parameters.
      rng: random key forwarded to `roberta_classifier.apply`.
      batch_token_ids: token ids for the batch.
      batch_labels: integer class labels for the batch.

    Returns:
      Scalar mean of the per-example correctness indicators.
    """
    logits = roberta_classifier.apply(params, rng, batch_token_ids, training=False)
    predicted_labels = jnp.argmax(logits, axis=-1)
    is_correct = predicted_labels == batch_labels
    return jnp.mean(is_correct)

@jax.jit
def update(params, rng, opt_state, batch_token_ids, batch_labels):
    """Run one optimisation step on a single batch.

    Differentiates `loss` with respect to `params`, feeds the gradients
    through the module-level optimiser `opt`, and applies the resulting
    updates to the parameters.

    Args:
      params: current classifier parameters.
      rng: random key forwarded to `loss`.
      opt_state: current optimiser state (as produced by `opt.init`).
      batch_token_ids: token ids for the batch.
      batch_labels: integer class labels for the batch.

    Returns:
      A tuple `(new_params, new_opt_state, batch_loss)`.
    """
    batch_loss, grads = jax.value_and_grad(loss)(params, rng, batch_token_ids, batch_labels)
    # Forward the current params as well: optax transforms that depend on the
    # parameter values (e.g. weight decay) require them, and transforms that
    # do not need them simply ignore the argument.
    updates, opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, batch_loss

## Set up the optimiser with a learning-rate schedule

In [None]:
# Number of optimiser steps over which the learning rate decays to zero.
# linear_schedule holds its end value (0.0) for any step past total_steps.
total_steps = config['n_epochs'] * (n_examples // config['batch_size'])

# True leaves are fine-tuned, False leaves are frozen (-1 = train everything).
mask = FineTuningDepth(params,-1)
frozen_mask = jax.tree_util.tree_map(lambda trainable: not trainable, mask)

# Multiplier on the base learning rate: 1.0 at step 0, 0.0 at total_steps.
lr_scaling_schedule = optax.linear_schedule(
    1.0, 
    0.0, 
    total_steps, 
    transition_begin=0
    )

# Order matters here. Adam normalises by the running gradient magnitude, so
# rescaling the gradients BEFORE Adam leaves the step size essentially
# unchanged. The schedule therefore has to scale Adam's OUTPUT. The final
# scale(-lr) is the same step optax.adam applies for a constant learning rate.
trainable_opt = optax.chain(
    optax.clip_by_global_norm(max_norm=config['max_grad_norm']),
    optax.scale_by_adam(),
    optax.scale_by_schedule(lr_scaling_schedule),
    optax.scale(-config['learning_rate']),
)

# optax.masked only transforms leaves whose mask is True and passes every
# other update through UNCHANGED (i.e. as the raw gradient). Frozen leaves
# must therefore be zeroed explicitly, otherwise apply_updates would add the
# raw gradient to them.
opt = optax.chain(
    optax.masked(trainable_opt, mask),
    optax.masked(optax.set_to_zero(), frozen_mask),
)
opt_state = opt.init(params)

In [None]:
# Inspect the freshly initialised optimiser state.
opt_state

MaskedState(inner_state=(ScaleByScheduleState(count=DeviceArray(0, dtype=int32)), EmptyState(), (ScaleByAdamState(count=DeviceArray(0, dtype=int32), mu={'Transformer/Transformer/embedding/embed': {'embeddings': DeviceArray([[0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             ...,
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)}, 'Transformer/Transformer/embedding/layer_norm': {'offset': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

## Wandb logging function

In [None]:
def measure_current_performance(params, n_examples=None, splits=('train', 'test'),track=0):
    """Print (and optionally log to wandb) the accuracy on the requested splits.

    For each of 'train' and 'test' present in `splits` (always in that order),
    the split is loaded via `load_dataset` in evaluation mode with batches of
    25, `accuracy` is evaluated on every batch, and the mean of the per-batch
    accuracies is reported.

    Args:
      params: classifier parameters to evaluate.
      n_examples: forwarded to `load_dataset` (presumably caps the number of
        evaluated examples; confirm against `load_dataset`).
      splits: container of split names to evaluate.
      track: when truthy, each accuracy is also sent to `wandb.log`.

    Returns:
      None.
    """
    # (split name, console message template, wandb metric key)
    reports = (
        ('train', "\t Train acc: {:.3f}", "Train acc"),
        ('test', "\t Test accuracy: {:.3f}", "Test acc"),
    )
    for split_name, message, metric_key in reports:
        if split_name not in splits:
            continue
        eval_batches = load_dataset(
            split_name, training=False, batch_size=25, n_examples=n_examples
        )
        # The module-level `rng` is passed to every accuracy call.
        per_batch_accuracy = [
            accuracy(params, rng, encode_batch(eval_batch['text']), eval_batch['label'])
            for eval_batch in eval_batches
        ]
        split_accuracy = np.mean(per_batch_accuracy)
        print(message.format(split_accuracy))
        if track:
            wandb.log({metric_key: split_accuracy})


## Training loop with logging

In [None]:
# SECURITY: never hardcode the W&B API key in the notebook. Export
# WANDB_API_KEY in the environment before starting the kernel, or let
# wandb.login() prompt for it. The key that was previously committed in this
# cell must be treated as leaked and revoked.
track=1
if track:
  print('----Tracking----')
  wandb.login()
  wandb_run = wandb.init(
      project="RoBERTa",
      entity="mo379",
  )
else:
  wandb_run = False

# Steps per epoch; loop-invariant, so computed once.
n_steps = int(np.ceil(n_examples/config['batch_size']))
for epoch in range(config['n_epochs']):
  print(f'--------> epoch {epoch}')
  train_iter = iter(train)
  # zip() ends the epoch cleanly if the dataset yields fewer than n_steps
  # batches, instead of crashing with StopIteration from next().
  for step, next_batch in zip(range(n_steps), train_iter):
      if step % 30 == 0:
          print(f"--> step {step}")
          measure_current_performance(params, n_examples=500,track=track)
      # Perform adam update
      batch_text = txt_transform(next_batch['text'])
      batch_token_ids = encode_batch(batch_text)
      batch_labels = next_batch['label']
      # Derive a distinct key for every step so that stochastic layers do not
      # see the exact same randomness on every update. Assumes `rng` is a JAX
      # PRNG key - TODO confirm where `rng` is created.
      step_rng = jax.random.fold_in(rng, epoch * n_steps + step)
      params, opt_state, batch_loss = update(
          params, step_rng, opt_state, batch_token_ids, batch_labels
      )

# Close the run so buffered metrics are flushed to the W&B server.
if track:
  wandb_run.finish()



----Tracking----


[34m[1mwandb[0m: Currently logged in as: [33mmo379[0m. Use [1m`wandb login --relogin`[0m to force relogin


--------> epoch 0
--> step 0




	 Train acc: 0.500
	 Test accuracy: 0.502
--> step 30
	 Train acc: 0.500
	 Test accuracy: 0.502
--> step 60
	 Train acc: 0.822
	 Test accuracy: 0.792
--> step 90
	 Train acc: 0.854
	 Test accuracy: 0.828
--> step 120
	 Train acc: 0.878
	 Test accuracy: 0.882
--> step 150
	 Train acc: 0.860
	 Test accuracy: 0.864
--> step 180
	 Train acc: 0.890
	 Test accuracy: 0.898
--> step 210
	 Train acc: 0.882
	 Test accuracy: 0.890
--> step 240
	 Train acc: 0.906
	 Test accuracy: 0.902
--> step 270
	 Train acc: 0.894
	 Test accuracy: 0.894
--> step 300
	 Train acc: 0.908
	 Test accuracy: 0.902
--> step 330
	 Train acc: 0.834
	 Test accuracy: 0.818
--> step 360
	 Train acc: 0.908
	 Test accuracy: 0.900
--> step 390
	 Train acc: 0.910
	 Test accuracy: 0.906
--> step 420
	 Train acc: 0.916
	 Test accuracy: 0.910
--> step 450
	 Train acc: 0.892
	 Test accuracy: 0.886
--> step 480
	 Train acc: 0.916
	 Test accuracy: 0.912
--> step 510
	 Train acc: 0.902
	 Test accuracy: 0.894
--> step 540
	 Train acc: 

ValueError: ignored