In [1]:
import jax
import jax.numpy as jnp
import flax
from entmax_jax import sparsemax

In [None]:
from transformers import FlaxElectraForSequenceClassification, ElectraTokenizerFast, TensorType

tokenizer = ElectraTokenizerFast.from_pretrained('google/electra-small-discriminator')
model = FlaxElectraForSequenceClassification.from_pretrained('google/electra-small-discriminator')
module = model.module
params = model.params
full_params = {"params": params}

In [5]:
@jax.jit
def electra_attn_weights(input_ids, attention_mask, token_type_ids):
    position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
    return module.apply(
        full_params,
        input_ids, 
        attention_mask, 
        token_type_ids, 
        position_ids,
        return_dict=False, 
        output_attentions=True,
        unnorm_attention=True)[1]

encodings = tokenizer(["test", "double test test"], return_tensors=TensorType.JAX, padding=True)
weights = electra_attn_weights(**encodings)
print(flax.linen.softmax(weights[-1][:, -1, 0]))
#print(sparsemax(weights, axis=-1)[:, -1, -1])

[[0.25902152 0.16806439 0.253755   0.15957955 0.15957955]
 [0.25383192 0.16607553 0.1678231  0.16549358 0.24677587]]


In [9]:
[[0.48436862 0.05180468 0.4638266  0.         0.        ]
 [0.45878488 0.03455498 0.04502275 0.03104479 0.4305927 ]]

ElectraConfig {
  "architectures": [
    "ElectraForPreTraining"
  ],
  "attention_probs_dropout_prob": 0.1,
  "embedding_size": 128,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "electra",
  "num_attention_heads": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "summary_activation": "gelu",
  "summary_last_dropout": 0.1,
  "summary_type": "first",
  "summary_use_proj": true,
  "transformers_version": "4.7.0.dev0",
  "type_vocab_size": 2,
  "vocab_size": 30522
}



In [6]:
params = model.params
encodings = tokenizer("test", return_tensors=TensorType.JAX)

def attn_sum(params, input_ids, attention_mask, token_type_ids):
    weights = model(input_ids, attention_mask, token_type_ids, params=params, output_attentions=True)[2]
    return jnp.sum(jnp.mean(weights, axis=1))

grad_fn = jax.jit(jax.grad(attn_sum))
grad_fn(params, **encodings)

FrozenDict({
    embeddings: {
        layer_norm: {
            beta: DeviceArray([-2.85541546e-09,  3.65660324e-09,  1.03834719e-10,
                         -3.57843533e-09,  1.29741595e-09, -2.36547404e-09,
                         -3.13269233e-09, -3.65250896e-09,  4.30338609e-09,
                         -3.91636146e-09, -9.91454585e-10, -7.13687509e-09,
                          5.39629097e-09, -1.48370338e-09,  4.28959446e-09,
                         -3.12016502e-09,  2.57528576e-09,  1.03404652e-09,
                          2.74821876e-09,  5.43943601e-09,  2.31653308e-10,
                         -1.35383549e-09,  2.75576917e-09, -3.88164345e-09,
                          8.97568686e-11,  9.25719501e-11,  6.78937839e-09,
                          3.21812887e-09, -1.11014842e-09,  5.98481620e-09,
                         -4.93937558e-10,  1.63656444e-09,  2.22316876e-09,
                          8.15678991e-09, -6.66407485e-10, -2.23727192e-09,
                         -1.5