### $\textbf{Task: Implement a transformer model from scratch}$

### Motivating Attention

Recurrent architectures generate a series of internal states $h_{t}$ from the previous internal state $h_{t-1}$ and the value of the sequence at time $t$, this is inherently sequential and becomes a bottleneck for long sequences, such as large text corpus. Furthermore it struggles to model long range dependencies in the data, and relevance is dominanted by the recent past. The outputs of attention contains individual token information and contextual relationships between tokens too.

Attention calculates a contextualised meaning of a token, which is beneficial as the same token can have different meanings in different contexts. For example:
"It's not **fair**" vs "He has **fair** hair".

### Attention Mechanism

At the core of the transformer model is the attention mechanism, which allows for the modelling of dependencies without hinderance from distance apart in the sequence. 

Simply: **Attention aggregating information from all other tokens in the sequence, weighted by their relevance to the current token.**

Essentially, the attention mechanism comprises of three matrices:
 - $Q \in \mathbb{R}^{NQ \times d_{k}}$ the query matrix.
 - $K \in \mathbb{R}^{NK \times d_{k}}$ the key matrix.
 - $V \in \mathbb{R}^{NK \times d_{v}}$ the value matrix.

With each key (row in key matrix), there is a corresponding value (row in the value matrix). The formula of attention is rather intuiative, considering a single query $\vec{q}$ represented as row vector of length $d_{k}$, we calculate the dot-product between the query and every key in the key matix: $\vec{q} \cdot K^{T}$, therefore over a batch of $NQ$ queries assembled in a matrix $Q$ our pair-wise similarities is given by $Q\cdot K^{T} \in \mathbb{R}^{NQ \times NK}$. So, $\left(Q\cdot K^{T} \right)_{ij}$ is the similarity between query $i$ and key $j$. 

We then normalise over the number of values we have to ensure our similarity values are of similar variance as to when they were inputted, mainly it helps for numerical stability of the softmax function and ensuring we have non-negligible gradients (If any of the values are significantly larger than others then it basically diminishes all others to 0 regardless of their true value). We softmax over each row of this similarity matrix, to get the normalised similarities, so the query-key similarities of each row sums to $1$. These are effectively weightings of the corresponding values in the value matrix $V$.

$$ \texttt{att}(Q, K, V) = \texttt{softmax}\left(\frac{1}{\sqrt{d_{k}}} Q \cdot K^{T} \right) V. $$

For example, if a key and query are highly aligned, then the value normalised alignment will be close to $1$ and the resulting value will be similar to the value associated with the initial key.

Sometimes, the softmax is decomposed into two steps, first we exponentiate all the elements and then we normalise each row. However, we can apply a mask to the weights before doing the softmax, zeroing some elements. e.g. if we make the masking matrix all ones below the diagonal, then we make the attention mechanism causal. That is, only queries appearing before the key can be used for attention. Our model cannot use future tokens in predicting the present.

Finally, after the softmax, we can apply dropout. This effectively removes certain key query similarities, which could help our model by removing dominant key query similarities, so our model infers relationships between less similar key query pairs.


### Attention Implementation

The attention mechanism learns how to attend the information from one sequence to the information in another sequence. This is known as cross-attention. In the case that we want to learn about relationships within a sequence itself, these two sequences are the same, this is known as self-attention. 

In [2]:
import flax.linen as nn
import jax.numpy as jnp
import jax

from typing import Optional

class AttentionMechanism(nn.Module):
    d_key : int
    d_value : int
    dropout_rate : float = 0.0

    @nn.compact
    def __call__(self, 
        sequence : jnp.ndarray, 
        target_sequence : jnp.ndarray,
        mask : Optional[jnp.ndarray] = None, 
        dropout_rng = None
    ):
        # We project our sequence to keys, queries and values with learned projections
        query_matrix = nn.Dense(features=self.d_key, kernel_init=nn.initializers.lecun_normal())(target_sequence) # (NQ = len(target_sequence), d_key)
        key_matrix   = nn.Dense(features=self.d_key, kernel_init=nn.initializers.lecun_normal())(sequence)     # (NK = len(sequence), d_key)
        value_matrix = nn.Dense(features=self.d_value, kernel_init=nn.initializers.lecun_normal())(sequence) # (NK = len(sequence), d_value)

        un_normalised_weights = jnp.exp(query_matrix @ key_matrix.T)
        if (mask is not None):
            un_normalised_weights = mask * un_normalised_weights
        un_normalised_weights += 10e-5
        weights = un_normalised_weights / jnp.sum(un_normalised_weights, axis = 0)
        if self.dropout_rate > 0.0 and dropout_rng is not None:
            dropout_mask = jax.random.bernoulli(dropout_rng, shape=weights.shape, p = 1 - self.dropout_rate)
            weights *= dropout_mask / (1 - self.dropout_rate)
        return weights @ value_matrix

In [3]:
model = AttentionMechanism(d_key = 10, d_value=10)
rng, init_rng = jax.random.split(jax.random.PRNGKey(42), 2)
params = model.init(init_rng, jnp.ones((3, 4)), jnp.ones((3, 4)))
model.apply(params, jnp.ones((3, 4)), jnp.ones((3, 4)))

Array([[-0.6498003 , -0.6035281 ,  0.1453905 ,  2.3583808 , -0.22132717,
         0.10294747,  1.2881615 , -0.33101836, -0.62982154, -0.10315502],
       [-0.6498003 , -0.6035281 ,  0.1453905 ,  2.3583808 , -0.22132717,
         0.10294747,  1.2881615 , -0.33101836, -0.62982154, -0.10315502],
       [-0.6498003 , -0.6035281 ,  0.1453905 ,  2.3583808 , -0.22132717,
         0.10294747,  1.2881615 , -0.33101836, -0.62982154, -0.10315502]],      dtype=float32)

### Computational Cost of Attention

As with most deep learning architectures, the computational cost of attention lies in the matrix multiplications. 

- Calculating the key matrix: $O(NK \times d_{model} \times d_{key})$
- Calculating the value matrix: $O(NK \times d_{model} \times d_{value})$
- Calculating the query matrix: $O(NQ \times d_{model} \times d_{key})$
- $Q \times K^{T},~~((NQ \times d_{key}) \cdot (NK \times d_{key})^{T})$, which has a cost $O(NQ \times d_{key} \times NK)$.
- $\left( Q \times K^{T} \right) \times V,~~ ((NQ \times NK) \cdot (NK \times d_{value}))$, which has a cost $O(NQ \times NK \times d_{value})$.

So the total computational cost is $O(NQ \times NK \times (d_{key} + d_{value}))$.

The memory cost is the size of the key and value matrices, and the weight matrix: $O(NK \times d_{key} + NK \times d_{value} + NQ \times d_{key} + NQ \times NK)$.

## Multi-Head Attention

In practice, instead of having a single Attention head, the model can better learn different semantic and syntactic relationships by projecting the queries, keys and values to a lower dimension, and performing attention there, before concatenating the results and doing a final linear layer.

In [4]:
class MultiHeadedAttention(nn.Module):
    # For multi head attention, d_key = d_value = d_model, which is then reduced in each attention head
    d_model : int
    num_heads : int
    dropout_rate : float = 0.0

    @nn.compact
    def __call__(self,
        sequence : jnp.ndarray,
        target_sequence : jnp.ndarray,
        dropout_rng = None
    ):
        
        attention_outputs = []
        for head_i in range(self.num_heads):
            single_head_attention = AttentionMechanism(d_key= self.d_model // self.num_heads, d_value = self.d_model // self.num_heads, dropout_rate=self.dropout_rate)(sequence, target_sequence, dropout_rng)
            attention_outputs.append(single_head_attention)
        attention_outputs = jnp.concatenate(attention_outputs, axis=1)
        # Final projection of concatenated outputs
        return nn.Dense(features=self.d_model)(attention_outputs) # (NQ x self.d_model)

In [5]:
multi_head = MultiHeadedAttention(d_model = 32, num_heads=8)
rng, init_rng = jax.random.split(jax.random.PRNGKey(42), 2)
params = multi_head.init(init_rng, jnp.ones((10, 32)), jnp.ones((10, 32)))
multi_head.apply(params, jnp.ones((10, 32)), jnp.ones((10, 32)))

Array([[ 0.09218917,  1.021213  ,  0.38755322, -1.438241  ,  0.32310086,
        -0.5021071 ,  1.367448  ,  0.3488834 , -0.04968318,  1.2304572 ,
        -0.5584813 , -0.1837434 ,  1.2924222 ,  0.67134094,  0.86011845,
         1.0973927 ,  1.7717434 ,  0.25737727,  0.04999864, -0.06061977,
         0.12868129, -1.2585343 ,  0.25587434,  1.3699903 , -0.21463393,
        -0.59999275, -1.5419455 ,  0.14009482,  0.5536306 ,  0.3260527 ,
        -0.6475275 ,  0.336717  ],
       [ 0.09218917,  1.021213  ,  0.38755322, -1.438241  ,  0.32310086,
        -0.5021071 ,  1.367448  ,  0.3488834 , -0.04968318,  1.2304572 ,
        -0.5584813 , -0.1837434 ,  1.2924222 ,  0.67134094,  0.86011845,
         1.0973927 ,  1.7717434 ,  0.25737727,  0.04999864, -0.06061977,
         0.12868129, -1.2585343 ,  0.25587434,  1.3699903 , -0.21463393,
        -0.59999275, -1.5419455 ,  0.14009482,  0.5536306 ,  0.3260527 ,
        -0.6475275 ,  0.336717  ],
       [ 0.09218917,  1.021213  ,  0.38755322, -1.4382

### Parallelism of Attention Mechanism

In recurrent (autoregressive) architectures, we need to process the sequence sequentially as we require the internal state to be updates with the previous outputs/inputs to ensure our model is casual. However attention mechanisms calculuate attention to all tokens in the sequence at once, which can be made casual with masking, meaning we can input entire batches and process them at once. **Ultimately the attention mechanism, and therefore the transformer architecture, does not depend on sequential computation**.

### Self Attention and Cross Attention

Self-Attention and Cross-Attention refer to how we construct the key, query and value matrices in our model. In self-attention the keys, queries and values are all learned from the same sequence. The idea is to see how parts of the input sequence attend to different parts of the input sequence. This is used to learn dependencies between different parts of the input sequence.

Cross-Attention is when the queries come from one sequence but the keys and values come from another. This is used to see how tokens from one sequence attend to tokens from another sequence, learning relationships between elements in the two sequences, and it is commonly used in the decoder part of a Transformer model in sequence-to-sequence tasks, such as machine translation, where the decoder learns how to attend to the encoder's output while generating a target sequence. 

### Machine Translation Example:



# Transformer Architecture

## Encoder Block
With more attention blocks in series our embeddings become more and more contextualised.

Our encoder block uses residual connections to help mitigate the vanishing gradient problem in our deep architecture, it improves the flow of gradients and makes the updates more substantial deep into the network.

In [6]:
class EncoderBlock(nn.Module):
    d_model : int
    num_heads : int
    dropout_rate : float = 0.0
    
    @nn.compact
    def __call__(self,
        sequence : jnp.ndarray, # (NK=NQ, d_model) In the encoder, we only do self-attention, as such we only need one input sequence, which attends to itself
        dropout_rng = None
    ):
        attention_output = MultiHeadedAttention(d_model=self.d_model, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(sequence, sequence, dropout_rng) # (NK x d_model)
        residual = nn.LayerNorm()(attention_output + sequence)
        linear = nn.Dense(features = self.d_model)(residual)
        return nn.LayerNorm()(residual + linear)

In [7]:
encoder_block = EncoderBlock(d_model = 32, num_heads = 8)
rng, init_rng = jax.random.split(jax.random.PRNGKey(42), 2)
params = encoder_block.init(init_rng, jnp.ones((10, 32)))
encoder_block.apply(params, jnp.ones((10, 32)))

Array([[-1.6010164 , -0.74460006, -0.37307504, -0.81092244,  0.15336768,
        -0.06796968,  0.47183788,  0.26438254,  0.1073863 ,  1.5642092 ,
        -0.00715235,  0.81226546, -0.55497384,  0.18102111, -2.0238996 ,
         0.41321406, -0.19702047,  2.1614182 , -1.353574  ,  0.042718  ,
        -1.0291138 , -0.6032628 ,  0.9482072 ,  0.69363105, -0.75029397,
        -1.5485976 ,  0.68653333,  0.18166786, -0.06982882, -0.63098276,
         1.7092276 ,  1.9751959 ],
       [-1.6010164 , -0.74460006, -0.37307504, -0.81092244,  0.15336768,
        -0.06796968,  0.47183788,  0.26438254,  0.1073863 ,  1.5642092 ,
        -0.00715235,  0.81226546, -0.55497384,  0.18102111, -2.0238996 ,
         0.41321406, -0.19702047,  2.1614182 , -1.353574  ,  0.042718  ,
        -1.0291138 , -0.6032628 ,  0.9482072 ,  0.69363105, -0.75029397,
        -1.5485976 ,  0.68653333,  0.18166786, -0.06982882, -0.63098276,
         1.7092276 ,  1.9751959 ],
       [-1.6010164 , -0.74460006, -0.37307504, -0.8109

### Attention Invariance and Positional Encoding

Because the attention mechanism works via dot products across the entire key set, permuting the keys and associated values, makes no difference on the output. Furthermore permuting the queries will only permute the resulting attention matrix (`equivariance'). Therefore **attention models are oblivious to the relative positioning of tokens in the sequence**. This is highly suitable for stationary sequences - where the distribution of values is independent of position in the sequence. For many applications, such as natural language processing, the absolute positioning of tokens has important semantic and syntactic meaning, which necessitates positional encoding in our tokens.

The position encoding described in the original transformer model is given by:
- If feature_index is even $PE(position, feature\_index) = PE(position, 2i) = \sin(\frac{pos}{10000^{2i/d_{model}}})$.

- If feature_index is odd $PE(position, feature\_index + 1) = PE(position, 2i + 1) = \cos(\frac{pos}{10000^{2i/d_{model}}})$.

This is used as there is a linear relationship between the positional encodings, so they are easy to learn. The very low frequency means it is unlikely to repeat and multiple tokens don't have the same positional encoding.

$$ \sin(a + b) = \sin(a)\cos(b) + \sin(b)\cos(a) $$
$$ PE(k + n, 2i) = \sin(\frac{k}{\omega} + \frac{n}{\omega}) = \sin(\frac{k}{\omega})\cos(\frac{n}{\omega}) + \sin(\frac{n}{\omega})\cos(\frac{k}{\omega}) = PE(k, 2i)w_{1} + PE(n, 2i)w_{2}. $$

Positional Encoding is done before query/keys are constructed and done to the initial token embeddings.

In [8]:
class PositionalEncoder(nn.Module):
    num_data : int
    num_features : int

    @nn.compact
    def __call__(self):
        frequencies = jnp.array([[pos/(10_000 ** (2 * feat / self.num_features)) for feat in range(self.num_features)] for pos in range(self.num_data)])
        even_positions = jnp.sin(frequencies)
        odd_positions = jnp.cos(frequencies)
        positional_encodings = even_positions
        positional_encodings.at[:, 1::2].set(odd_positions[:, 1::2])
        return positional_encodings

In the encoder, the only attention is self-attention over the input sequence, here our model learns patterns in the input sequence. Additionally, we need to project the data to the model dimensionality first and then add our positional encodings.

In [9]:
class Encoder(nn.Module):
    d_model : int
    num_heads : int
    num_encoder_blocks : int
    dropout_rate : float = 0.0

    @nn.compact
    def __call__(self,
        sequence : jnp.ndarray, # (NQ=NK, embedding_dimension)
        dropout_rng = None
    ):
        project_sequence = nn.Dense(features=self.d_model)(sequence)
        positional_encodings = PositionalEncoder(num_data = sequence.shape[0], num_features = self.d_model)()
        project_sequence += positional_encodings
        for block in range(self.num_encoder_blocks):
            if dropout_rng is not None:
                dropout_rng, block_rng = jax.random.split(dropout_rng, 2)
                x = EncoderBlock(d_model = self.d_model, num_heads = self.num_heads, dropout_rate = self.dropout_rate)(project_sequence, block_rng)
            else:
                x = EncoderBlock(d_model = self.d_model, num_heads = self.num_heads, dropout_rate = self.dropout_rate)(project_sequence)
        return x

In [10]:
encoder = Encoder(d_model = 32, num_heads = 8, num_encoder_blocks=6, dropout_rate=0.0)
rng, init_rng = jax.random.split(jax.random.PRNGKey(42), 2)
params = encoder.init(init_rng, jnp.ones((10, 128)))
encoder.apply(params, jnp.ones((10, 128)))

Array([[ 0.2270196 , -0.88147086,  1.7586254 ,  0.61025095, -1.4935668 ,
        -0.09663835, -0.25542971, -1.6849537 , -0.85669243,  1.7140261 ,
         1.0995613 , -1.9956753 ,  0.37985593, -1.5971788 ,  0.01673846,
         1.9369373 ,  0.7516392 , -0.56977063,  0.01491532, -0.5858998 ,
        -0.39314574,  0.04456951,  0.47563976, -0.17209879, -0.85661393,
         0.47149247, -0.1941591 ,  0.2547242 ,  1.9583863 , -0.23300745,
        -0.49226698,  0.6441873 ],
       [ 0.5688276 , -0.7960301 ,  1.6817492 ,  0.69646233, -1.405921  ,
         0.48810703, -0.15471283, -1.7160313 , -0.9328121 ,  1.3078345 ,
         1.1444054 , -1.9570963 , -0.01601211, -1.9795674 ,  0.17113669,
         1.7458769 ,  0.727223  , -0.6600679 ,  0.1649453 , -0.7671649 ,
        -0.01673924,  0.15837243,  0.5285803 , -0.65436304, -0.62370676,
         0.1977588 ,  0.0137896 ,  0.3212364 ,  2.0461264 ,  0.21390757,
        -0.8088062 ,  0.3126921 ],
       [ 0.6432511 , -0.68558127,  1.6664013 ,  0.7591

## Decoder

Now let's understand and implement the decoder block. In the original transformer architecture, the decoder contains two attention mechanisms, a self-attention of the produced output sequence so far, to itself, which learns how the output embeddings relate to each other. And a cross-attention between the output tokens and the output of the encoder, which how tokens in our output sequence relate to the input sequence. Therefore, it still produces translations in a sequential, auto-regressive manner. In the cross-attention, the keys and values are from the english sequence and the queries from our german sequence, the output therefore is what the meaning of each word is, in the english context.

In the decoder masking is used as we do not know the future predicted tokens, its used in inference for self attention.

In [11]:
class DecoderBlock(nn.Module):

    d_model : int
    num_heads : int
    dropout_rate : float = 0.0

    @nn.compact
    def __call__(self,
        encoder_output,
        output_sequence, 
        dropout_rng = None
    ):
        output_self_attention = MultiHeadedAttention(d_model=self.d_model, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(output_sequence, output_sequence, dropout_rng) # (NK x d_model)
        residual = nn.LayerNorm()(output_self_attention + output_sequence)
        
        cross_attention = MultiHeadedAttention(d_model=self.d_model, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(encoder_output, output_self_attention, dropout_rng)
        residual = nn.LayerNorm()(cross_attention + residual)

        linear = nn.Dense(features = self.d_model)(residual)
        return nn.LayerNorm()(residual + linear)

In [12]:
decoder_block = DecoderBlock(d_model = 32, num_heads=8, dropout_rate=0.0)
rng, init_rng = jax.random.split(jax.random.PRNGKey(42), 2)
params = decoder_block.init(init_rng, jnp.ones((1, 32)), jnp.ones((10, 32)))
decoder_block.apply(params, jnp.ones((10, 32)), jnp.ones((1, 32)))

Array([[ 0.8705498 , -0.29904237,  1.065743  , -0.92596537,  0.8594881 ,
        -0.06898125,  0.43973774, -0.40434396,  0.88021624, -0.5962909 ,
        -1.3918655 ,  2.4014132 ,  0.4867272 ,  0.6781453 ,  1.2331144 ,
        -2.0076938 , -0.03978869,  0.5281033 , -0.2670866 ,  1.6258852 ,
        -1.3171986 ,  0.91133046,  1.0324557 , -1.4080709 , -1.2728466 ,
        -0.5505679 , -0.05560639, -0.73855853,  0.5119102 , -0.9049106 ,
        -0.6274499 , -0.64855194]], dtype=float32)

In [13]:
class Decoder(nn.Module):
    d_model : int
    num_heads : int
    num_decoder_blocks : int
    dropout_rate : float = 0.0

    @nn.compact
    def __call__(self,
        encoder_output,
        output_sequence,
        dropout_rng = None
    ):
        projected_output = nn.Dense(features=self.d_model)(output_sequence)
        positional_encodings = PositionalEncoder(num_data=output_sequence.shape[0], num_features=self.d_model)()
        projected_output += positional_encodings
        x = projected_output
        for decoder_block in range(self.num_decoder_blocks):
            if (dropout_rng is not None and self.dropout_rate != 0.0):
                dropout_rng, block_rng = jax.random.split(dropout_rng, 2)
                x = DecoderBlock(d_model=self.d_model, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(encoder_output, x, block_rng)
            else:
                x = DecoderBlock(d_model=self.d_model, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(encoder_output, x)
        return x

In [14]:
decoder = Decoder(d_model=32, num_heads=8, num_decoder_blocks=6, dropout_rate=0.0)
rng, init_rng = jax.random.split(jax.random.PRNGKey(42), 2)
params = decoder.init(init_rng, jnp.ones((10, 32)), jnp.ones((2, 32)))
decoder.apply(params, jnp.ones((10, 32)), jnp.ones((2, 32)))

Array([[-2.46774   ,  0.30575985, -1.2096788 , -0.01589864, -1.4317238 ,
         0.15191452,  0.59750885,  1.4253247 , -0.97455406,  1.2900374 ,
        -0.16072749,  1.1411221 ,  0.58991116,  0.14383507,  1.805051  ,
        -0.24406984,  1.091023  , -0.10063185,  1.3763161 ,  0.39037886,
        -0.41555667,  0.45323968,  0.4349356 , -1.3794247 ,  0.06797572,
         0.2378329 , -0.25623277,  1.2141205 , -1.3303334 , -0.64375865,
        -1.5839705 , -0.5019853 ],
       [-2.6987092 , -0.50889957, -0.7936967 ,  1.5718756 ,  0.22204833,
         0.21450977,  0.7557413 ,  1.8009727 , -0.7372998 ,  0.48573074,
         1.0401368 ,  0.95349705,  0.11769154,  0.73372597, -0.47389382,
        -0.4456925 , -1.3872129 , -0.9461556 ,  1.128169  , -1.0824682 ,
         0.99234533,  0.8792735 ,  0.1793544 , -1.1905773 , -0.37299743,
         0.80755424, -0.20205463, -0.2735211 ,  1.4139537 , -1.2048684 ,
        -0.04597928, -0.9325538 ]], dtype=float32)

## Transformer Class

In [15]:
class Transformer(nn.Module):
    d_model : int
    num_heads : int
    num_encoder_blocks : int
    num_decoder_blocks : int
    dropout_rate : float
    output_dim : int
    
    @nn.compact
    def __call__(self,
        input_sequence,
        output_sequence,
        dropout_rng = None
    ):
        encoder_output = Encoder(d_model=self.d_model, num_encoder_blocks=self.num_encoder_blocks, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(input_sequence)
        decoder_output = Decoder(d_model=self.d_model, num_decoder_blocks=self.num_decoder_blocks, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(encoder_output, output_sequence)
        return nn.softmax(nn.Dense(features=self.output_dim)(decoder_output))

In [18]:
transformer = Transformer(d_model=128, num_heads=8, num_encoder_blocks=6, num_decoder_blocks=6, dropout_rate=0.0, output_dim=16)
rng, init_rng = jax.random.split(jax.random.PRNGKey(42), 2)
input_sequence_encodings = jnp.ones((10, 12))
output_generated_so_far = jnp.ones((2, 16))
prediction_token = jnp.zeros((1, 16))
output_generated_so_far = jnp.concatenate([output_generated_so_far, prediction_token], axis=0)
params = transformer.init(init_rng, input_sequence_encodings, output_generated_so_far)
transformer.apply(params, input_sequence_encodings, output_generated_so_far)[-1, :]

Array([0.03418443, 0.03028339, 0.03722412, 0.02091815, 0.02253244,
       0.03425919, 0.03262088, 0.06958006, 0.04557904, 0.28960752,
       0.02728707, 0.15572648, 0.09763368, 0.00633427, 0.03984687,
       0.05638238], dtype=float32)