In [1]:

from flax import nnx


In [30]:
#### TODO redo a graphormer achitecture (with a specific envelop term)

"""
Model architecture backbone
It will be two things : 

Essentially, it will be the backbone of the transformer world model

We will try to code the transformer in jax (flax)

Homemade version of the transformer

"""

from flax import nnx
import jax
import jax.numpy as jnp
import einops


class FeedForward(nnx.Module):
    """
    Feed forward layer
    """

    def __init__(
        self,
        d_model: int = 512,
        dim_feedforward: int = 2048,
        rngs=None,
    ):
        super().__init__()
        self.linear1 = nnx.Linear(
            in_features=d_model, out_features=dim_feedforward, rngs=rngs
        )
        self.linear2 = nnx.Linear(
            in_features=dim_feedforward, out_features=d_model, rngs=rngs
        )

    def __call__(self, x):
        x = self.linear1(x)
        x = nnx.gelu(x)
        x = self.linear2(x)

        return x

class BiaisMultiHeadAttnetion(nnx.Module):
    """
    This setup is used to input informations from the graph into the attention layer
    """
    def __init__(self, num_heads: int = 8,
            in_features=512,
            qkv_features=512,
            rngs=None):
        super().__init__()

        self.num_heads = num_heads
        self.in_features = in_features
        self.qkv_features = qkv_features

        # init the layers
        self.linear_qkv = nnx.Linear(
            in_features=in_features, out_features=qkv_features*3, rngs=rngs
        )

        # final linear layer 
        self.linear_last = nnx.Linear(
            in_features=in_features, out_features=in_features, rngs=rngs
        )
    
    def __call__(self, x, edge):
        """
        x is the node information (nb_batch, seq_len, nb_features)
        and edges is (nb_batch, nb_head, seq_len, seq_len)
        """
        
        # first pass with linear_qkv
        x = self.linear_qkv(x)

        # split to go from (nb_batch, seq_len, qkv_features*3) to (nb_batch, nb_head, seq_len, qkv_features // nb_head, 3)  
        x = einops.rearrange(
            x,
            'b s (h f d) -> b h s f d',
            h=self.num_heads,
            d=3
        )

        # 
        query = x[:, :, :, :, 0]
        keys = x[:, :, :, :, 1]
        values = x[:, :, :, :, 2]

        # Compute the dot product between query and keys
        qk = jnp.einsum('b h i f, b h j f -> b h i j', query, keys)

        # adding biais from edges info
        qk = qk + edge

        # Scale the dot product by the square root of the feature dimension
        qk_scaled = qk / jnp.sqrt(self.qkv_features // self.num_heads)

        # Apply softmax to compute attention weights (optional)
        attention_weights = jax.nn.softmax(qk_scaled, axis=-1)

        # Compute the weighted sum of the values using the attention weights
        output = jnp.einsum('b h i j, b h j f -> b h i f', attention_weights, values)

        # Concatenate the outputs from all the heads
        output = einops.rearrange(output, 'b h s f -> b s (h f)')

        return self.linear_last(output)



class TransformerBlock(nnx.Module):
    """
    Transformer block

    1. Layer Norm
    2. Multi-Head Attention
    3. Layer Norm
    4. Feed Forward

    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        rngs=None,
    ):
        super().__init__()


        # init layernorm
        self.layernorm1 = nnx.LayerNorm(num_features=d_model, rngs=rngs)

        # init multi-head attention
        self.multihead = BiaisMultiHeadAttnetion(
            num_heads=nhead,
            in_features=d_model,
            qkv_features=d_model,
            rngs=rngs,
        )

        # init layernorm
        self.layernorm2 = nnx.LayerNorm(num_features=d_model, rngs=rngs)

        # init feed forward
        self.feedforward = FeedForward(
            d_model=d_model, dim_feedforward=dim_feedforward, rngs=rngs
        )

        self.dropout = nnx.Dropout(dropout, rngs=rngs)

        self.layer_norm_eps = layer_norm_eps

    def __call__(self, x, edge):

        x_forward = self.layernorm1(x)

        x_forward = self.multihead(x_forward, edge)

        x_forward = self.dropout(x_forward)
        x_forward = x + x_forward
        x_forward_second = self.layernorm2(x_forward)
        x_forward_second = self.feedforward(x_forward_second)
        x_forward_second = self.dropout(x_forward_second)
        x_forward_second = x_forward + x_forward_second

        return x_forward_second


class Transformer(nnx.Module):
    """
    Transformer model
    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.,
        # decoder only
        layer_norm_eps: float = 1e-5,
        out_features: int = 64,
        rngs=None,

    ):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout


        # we setup a stack of transformer blocks
        self.transformer = nnx.List(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    dropout=dropout,
                    layer_norm_eps=layer_norm_eps,
                    rngs=rngs,
                )
                for _ in range(num_decoder_layers)
            ],
        )

        # now the last layer norm and linear layer
        self.layernorm = nnx.LayerNorm(num_features=d_model, rngs=rngs)
        self.linear = nnx.Linear(
            in_features=d_model, out_features=out_features, rngs=rngs
        )

    def __call__(self, x, edge):

        for i in range(self.num_decoder_layers):
            x = self.transformer[i](x, edge)

        x = self.layernorm(x)
        x = self.linear(x)

        return x

In [31]:
nb_batch = 32
seq_len = 12
qkv_features = 512
nb_head=8

rngs = nnx.Rngs(44)

### test session
model = Transformer(rngs=rngs)

from jax import random
key = random.key(0)

exemple_node = random.normal(key, (nb_batch, seq_len, qkv_features))

edge = random.normal(key,(nb_batch, nb_head, seq_len, seq_len))


In [33]:
model(exemple_node, edge)

(32, 12, 64)

In [9]:
exemple_node

Array([[[-1.8483702 ,  0.18487331,  2.2878232 , ...,  0.846782  ,
          0.84857917, -0.10905278],
        [-0.6393625 , -1.0291516 ,  0.94285446, ..., -0.3374763 ,
         -0.17158249, -0.51815677],
        [-0.62649816,  1.3318139 , -1.9166517 , ...,  0.9280785 ,
         -0.56609195, -2.351592  ],
        ...,
        [-0.50742257,  0.16040553, -1.1405077 , ..., -0.2639426 ,
          0.5241131 ,  0.44146177],
        [ 0.5251477 ,  0.656969  ,  0.33670473, ..., -0.04454711,
          0.39400226, -0.4205947 ],
        [-0.8347021 , -0.9973777 , -1.1542041 , ...,  0.812793  ,
          0.9061063 ,  0.78657585]],

       [[ 0.758117  , -0.25351432, -0.13754742, ...,  0.12224361,
          1.4125918 , -0.7490418 ],
        [-0.6561344 ,  1.6324428 ,  0.52141684, ...,  0.7452573 ,
         -0.03740723, -0.8705038 ],
        [-1.407909  , -0.75454223,  0.5740301 , ..., -0.91798043,
          0.5860514 ,  1.6368006 ],
        ...,
        [-0.06043135, -1.4515245 ,  0.03387675, ...,  