<a href="https://colab.research.google.com/github/Edwin372/BasicHandsOnML/blob/main/SentencePiece_and_BPE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install trax


<a name="1.2"></a>
## Part 1.2  Trax Details
The goal in this notebook is to override a few routines in the Trax classes with our own versions. So that they keep working in a full Trax environment, this code preserves many details that a simplified example version of these routines might ignore. Here are some of the considerations that may impact our code:
* Trax operates with multiple back-end libraries, so we will see special cases that use features unique to a particular backend.
* 'Fancy' NumPy indexing is not supported in every backend environment and must be emulated in other ways.
* Some operations have no gradient for backprop; they must either be excluded from the gradient calculation or be forced to re-evaluate.
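As one illustration of emulating fancy indexing, the plain NumPy sketch below selects rows with a one-hot matrix product instead of `x[idx]`. The helper `gather_rows_one_hot` is hypothetical and is not how Trax itself does this; it only shows that a gather can be expressed with an element-wise comparison and a matmul, which are broadly supported operations.

```python
import numpy as np

def gather_rows_one_hot(x, idx):
  """Return x[idx] without fancy indexing (hypothetical helper, not Trax).

  x: array of shape (n, d); idx: integer array of shape (m,).
  """
  n = x.shape[0]
  # one_hot[i, j] == 1.0 exactly when idx[i] == j
  one_hot = (idx[:, None] == np.arange(n)[None, :]).astype(x.dtype)
  # Each output row is the single row of x selected by its one-hot vector.
  return one_hot @ x

x = np.arange(12, dtype=np.float32).reshape(4, 3)
idx = np.array([2, 0, 3])
print(gather_rows_one_hot(x, idx))
print(np.array_equal(gather_rows_one_hot(x, idx), x[idx]))  # True
```

The trade-off is memory: the one-hot matrix has shape `(len(idx), n)`, so this is only reasonable for small `n`. Where a backend offers a gather such as `np.take(x, idx, axis=0)`, that is usually preferable.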

Here are some of the functions we may see:
* Abstracted as `fastmath`, Trax supports multiple backends such as [Jax](https://github.com/google/jax) and [TensorFlow 2](https://github.com/tensorflow/tensorflow).
* [tie_in](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.tie_in.html): Some non-numeric operations must be invoked during backpropagation. Normally the gradient compute graph determines what gets invoked, but these functions are not part of that graph. To force re-evaluation, they are 'tied' to other numeric operations using `tie_in`. (In recent JAX releases `tie_in` is deprecated: it ignores its first argument and returns the second.)
* [stop_gradient](https://trax-ml.readthedocs.io/en/latest/trax.fastmath.html): Some operations are intentionally excluded from backprop gradient calculations by treating their gradients as zero.
* Below we will execute `from trax.fastmath import numpy as np`, which provides accelerated versions of NumPy functions. This is, however, only a *subset* of NumPy.

In [None]:
import trax
import os
from trax import layers as tl
from trax import fastmath
from trax.fastmath import numpy as np

In [None]:
import jax
# use_backend() is a context manager and only takes effect inside a `with`
# block; set_backend() switches the Trax backend globally.
fastmath.set_backend('jax')

In [None]:
# tie_in is imported directly from JAX (a private module) rather than from trax.layers.
from jax._src.lax.lax import tie_in
from trax.layers import (
    # tie_in,  # ties a non-numeric operation to a numeric one to force re-evaluation
    length_normalized,
    apply_broadcasted_dropout,
    look_adjacent,
    permute_via_gather,
    permute_via_sort
)



In some applications, some values are masked. This can be used, for example, to exclude results that occur later in time (causal masking) or to mask padding or other inputs.
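A minimal NumPy sketch of how such masks can be built from position information: comparing a column of query positions with a row of key positions broadcasts into a full query-by-key mask. Marking padded keys with a negative position is an assumption here, borrowed from the `masked` branch of `mask_self_attention` (defined below), which flags `kv_info < 0`.

```python
import numpy as np

# Four token positions. Broadcasting a column (queries) against a row (keys)
# compares every query position with every key position.
positions = np.arange(4)
q_info = positions[:, None]   # shape (4, 1)
kv_info = positions[None, :]  # shape (1, 4)

# Causal mask: 1.0 where the key comes later than the query.
causal_mask = (q_info < kv_info).astype(np.float32)
print(causal_mask)

# Padding mask, assuming padded keys carry a negative position id.
kv_info_padded = np.array([[0, 1, 2, -1]])  # last key is padding
padding_mask = (kv_info_padded < 0).astype(np.float32)
print(padding_mask)  # shape (1, 4); broadcasts over all queries
```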

The routine below, `mask_self_attention`, implements a flexible masking capability. The masking is controlled by the information in `q_info` and `kv_info`.

In [None]:

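def mask_self_attention(
    dots, q_info, kv_info, causal=True, exclude_self=True, masked=False
):
  """Perform masking for self-attention.

  dots: attention scores (query-key dot products)
  q_info: query-associated metadata for masking (e.g. positions)
  kv_info: key-associated metadata for masking (e.g. positions)
  causal: mask keys that come after the query (q_info < kv_info)
  exclude_self: mask each query's own position (q_info == kv_info)
  masked: mask keys flagged with a negative kv_info (e.g. padding)

  Masked entries get a large negative offset (-1e9), so they receive
  (near-)zero weight after a softmax.
  """
  if causal:
    mask = fastmath.lt(q_info, kv_info).astype(np.float32)
    dots = dots - 1e9 * mask
  if exclude_self:
    mask = np.equal(q_info, kv_info).astype(np.float32)
    dots = dots - 1e9 * mask
  if masked:
    # tie_in ties the zeros to kv_info so they are re-evaluated with it
    zeros_like_kv_info = tie_in(kv_info, np.zeros_like(kv_info))
    mask = fastmath.lt(kv_info, zeros_like_kv_info).astype(np.float32)
    dots = dots - 1e9 * mask
  return dots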
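To see what the `-1e9` offsets do, here is a plain NumPy sketch (not the Trax code) that mirrors the `causal` and `exclude_self` branches on equal raw scores and then applies a softmax. It also exposes an edge case of the code above as written: with both flags set, query 0 has no unmasked key left, so all of its scores receive the same offset and the softmax falls back to uniform weights.

```python
import numpy as np

def softmax(x):
  # Shift by the row max before exponentiating to avoid overflow.
  e = np.exp(x - x.max(axis=-1, keepdims=True))
  return e / e.sum(axis=-1, keepdims=True)

positions = np.arange(3)
q_info, kv_info = positions[:, None], positions[None, :]
dots = np.zeros((3, 3), dtype=np.float32)  # identical raw scores everywhere

causal = (q_info < kv_info).astype(np.float32)
self_mask = (q_info == kv_info).astype(np.float32)

# causal=True, exclude_self=False: each query spreads its weight over itself
# and earlier keys only.
causal_weights = softmax(dots - 1e9 * causal)
print(np.round(causal_weights, 3))

# causal=True, exclude_self=True: every key of query 0 gets -1e9, so its row
# becomes uniform instead of all-zero.
both_weights = softmax(dots - 1e9 * causal - 1e9 * self_mask)
print(np.round(both_weights, 3))
```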

 

In [None]:
a = np.arange(9).reshape(3,3)
a

DeviceArray([[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]], dtype=int32)

In [None]:
b = np.arange(8,-1,-1).reshape(3,3)
b

DeviceArray([[8, 7, 6],
             [5, 4, 3],
             [2, 1, 0]], dtype=int32)

In [None]:
fastmath.lt(a,b).astype(np.float32)

DeviceArray([[1., 1., 1.],
             [1., 0., 0.],
             [0., 0., 0.]], dtype=float32)

In [None]:
np.zeros_like(b)

DeviceArray([[0, 0, 0],
             [0, 0, 0],
             [0, 0, 0]], dtype=int32)

In [None]:
zeros_like_kv_info = tie_in(a, np.zeros_like(b))

In [None]:
zeros_like_kv_info

DeviceArray([[0, 0, 0],
             [0, 0, 0],
             [0, 0, 0]], dtype=int32)

In [None]:
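def our_softmax(x, passthrough=False):
  """Softmax over the last axis that also returns logsumexp.

  If passthrough is True, the softmax is skipped: x is returned unchanged
  together with a zero logsumexp.
  """
  # logsumexp(x) = log(sum(exp(x))) over the last axis, kept for broadcasting.
  logsumexp = fastmath.logsumexp(x, axis=-1, keepdims=True)
  # exp(x - logsumexp) == exp(x) / sum(exp(x)); every exponent is <= 0.
  o = np.exp(x - logsumexp)
  if passthrough:
    return (x, np.zeros_like(logsumexp))
  else:
    return (o, logsumexp)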
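Why compute `exp(x - logsumexp)` rather than `exp(x) / sum(exp(x))`? A plain NumPy sketch: the naive form overflows for large inputs, while the max-shifted logsumexp keeps every exponent at or below zero. `stable_logsumexp` is a hypothetical helper; we assume `fastmath.logsumexp` is numerically stable in the same way, but this is not its implementation.

```python
import numpy as np

def stable_logsumexp(x):
  # log(sum(exp(x))) = m + log(sum(exp(x - m))) with m = max(x); shifting by
  # m keeps every exponent <= 0, so exp() cannot overflow.
  m = np.max(x, axis=-1, keepdims=True)
  return m + np.log(np.sum(np.exp(x - m), axis=-1, keepdims=True))

x = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
  naive = np.exp(x) / np.sum(np.exp(x))  # exp(1000) -> inf, inf / inf -> nan
stable = np.exp(x - stable_logsumexp(x))
print(naive)   # [nan nan nan]
print(stable)  # same as the softmax of [0, 1, 2]
```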

In [None]:
a = np.array([1.0, 2.0, 3.0, 4.0])
sma = np.exp(a) / sum(np.exp(a))
print(sma)
sma2, a_logsumexp = our_softmax(a, passthrough=True)
print(sma2)
print(a_logsumexp)

[0.0320586  0.08714432 0.23688282 0.6439142 ]
[1. 2. 3. 4.]
[0.]
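Returning `logsumexp` alongside the probabilities is what allows a softmax to be normalized piecewise and merged afterwards. As we understand it, Trax's efficient (Reformer-style) attention uses this kind of identity to combine separately normalized results; the NumPy sketch below only demonstrates the identity itself. `softmax_with_lse` is a hypothetical stand-in for `our_softmax(x, passthrough=False)`.

```python
import numpy as np

def softmax_with_lse(x):
  # Same contract as our_softmax(x, passthrough=False), in plain NumPy.
  m = np.max(x, axis=-1, keepdims=True)
  lse = m + np.log(np.sum(np.exp(x - m), axis=-1, keepdims=True))
  return np.exp(x - lse), lse

x = np.array([1.0, 2.0, 3.0, 4.0])
o1, lse1 = softmax_with_lse(x[:2])  # softmax over the first chunk only
o2, lse2 = softmax_with_lse(x[2:])  # softmax over the second chunk only

# Merge: the global normalizer is logaddexp of the chunk normalizers, and each
# chunk's probabilities are rescaled by exp(lse_chunk - lse_total).
lse = np.logaddexp(lse1, lse2)
merged = np.concatenate([o1 * np.exp(lse1 - lse), o2 * np.exp(lse2 - lse)])

full, full_lse = softmax_with_lse(x)
print(np.allclose(merged, full), np.allclose(lse, full_lse))  # True True
```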
