In [1]:
import jax
import jax.numpy as jnp

import flax
import flax.linen as nn

import optax



  PyTreeDef = type(jax.tree_structure(None))


In [2]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.shape[-1]

    score = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
    socre = score / jnp.sqrt(d_k)

    if mask is not None:
        socre = jnp.where(mask ==0, -9e15, score)
    
    attention = nn.softmax(score, axis=-1)
    values = jnp.matmul(attention, v)

    return values, attention

In [3]:
seq_len, d_k = 3, 2
main_rng, rand1 = jax.random.split(jax.random.PRNGKey(42))
qkv = jax.random.normal(rand1, (3, seq_len, d_k))
q, k, v = qkv[0], qkv[1], qkv[2]
values, attention = scaled_dot_product(q, k, v)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("Values\n", values)
print("Attention\n", attention)

Q
 [[-0.6613315   0.70056266]
 [ 0.08239268 -1.7793142 ]
 [-0.04378588  1.0965251 ]]
K
 [[ 1.7257481   0.35568172]
 [ 1.3034704   1.2873708 ]
 [ 1.6871481  -0.5714404 ]]
V
 [[ 1.5129997   1.1050899 ]
 [ 0.27949408 -0.46224892]
 [-1.1003422  -1.1437942 ]]
Values
 [[ 0.40075538 -0.1672631 ]
 [-0.650403   -0.7712134 ]
 [ 0.45445058 -0.14728263]]
Attention
 [[0.2453795  0.62314427 0.13147624]
 [0.15692098 0.02888096 0.8141981 ]
 [0.23855191 0.6749896  0.0864585 ]]


In [4]:
class MultiheadAttention(nn.Module):
    embed_dim: int # Output dimension
    num_heads: int # Number of parllel heads (h)

    def setup(self):
        self.qkv_proj = nn.Dense(
            3*self.embed_dim,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )

        self.o_proj = nn.Dense(
            self.embed_dim,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )
    
    def __call__(self, x, mask=None):
        batch_size, seq_len, embed_dim  = x.shape

        qkv = self.qkv_proj(x)

        # Seperate Q, K, V
        qkv = qkv.reshape(batch_size, seq_len, self.num_heads, -1)
        qkv = qkv.transpose(0, 2, 1, 3) # batch, num_heads, seq_len, Dims
        q, k, v = jnp.split(qkv, 3, axis=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.transpose(0, 2, 1, 3) # batch, seq_len, num_heads, Dims
        values = values.reshape(batch_size, seq_len, embed_dim)
        o = self.o_proj(values)

        return o, attention


In [5]:
## Test MultiheadAttention implementation
# Example features as input
main_rng, x_rng = jax.random.split(jax.random.PRNGKey(42))
x = jax.random.normal(x_rng, (3, 16, 128))
# Create attention
mh_attn = MultiheadAttention(embed_dim=128, num_heads=4)
# Initialize parameters of attention with random key and inputs
main_rng, init_rng = jax.random.split(jax.random.PRNGKey(42))
params = mh_attn.init(init_rng, x)['params']
# Apply attention with parameters on the inputs
out, attn = mh_attn.apply({'params': params}, x)
print('Out', out.shape, 'Attention', attn.shape)

del mh_attn, params

Out (3, 16, 128) Attention (3, 4, 16, 16)


In [6]:
from datasets import load_dataset

dataset = load_dataset("glue", "mrpc", split="train")


Downloading builder script:   0%|          | 0.00/7.78k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/4.47k [00:00<?, ?B/s]

Downloading and preparing dataset glue/mrpc (download: 1.43 MiB, generated: 1.43 MiB, post-processed: Unknown size, total: 2.85 MiB) to /home/gholamhossin/local/labs/.hugging_face/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...




Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data: 0.00B [00:00, ?B/s]

Downloading data: 0.00B [00:00, ?B/s]

Downloading data: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/3668 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/408 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1725 [00:00<?, ? examples/s]

Dataset glue downloaded and prepared to /home/gholamhossin/local/labs/.hugging_face/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


In [9]:
dataset

{'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .',
 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .',
 'label': 1,
 'idx': 0}

In [10]:
from transformers import FlaxAutoModelForSequenceClassification, AutoTokenizer
model = FlaxAutoModelForSequenceClassification.from_pretrained("bert-base-uncased")


Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading flax_model.msgpack:   0%|          | 0.00/418M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-unca

In [25]:
from datasets import Dataset
import numpy as np
ds = dataset.with_format('tf')

In [27]:
ds[0]

{'sentence1': <tf.Tensor: shape=(), dtype=string, numpy=b'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .'>,
 'sentence2': <tf.Tensor: shape=(), dtype=string, numpy=b'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .'>,
 'label': <tf.Tensor: shape=(), dtype=int64, numpy=1>,
 'idx': <tf.Tensor: shape=(), dtype=int64, numpy=0>}