# Industrial Transformers

## Libraries

Based on the [Vanilla Transformers](https://colab.research.google.com/drive/1VAsHQLrCSNb4V_c-mCFdIYfQBXXIQYz0#scrollTo=_QWiUFmzTkXL) notebook.


In [3]:
!nvidia-smi

Sat Jun 17 15:19:03 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   32C    P0    26W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
import math
from numbers import Number
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass, field

import jax
import jax.numpy as jnp
from jax import random

import optax
from jax import lax
import flax
from flax import linen as nn

# from flax import optim
from flax import jax_utils
from flax.training import train_state, checkpoints, common_utils

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from functools import partial

ModuleNotFoundError: No module named 'jax'

# Data Preprocessing

- Drop all completely empty columns.
- Make local_* variables strings.
- Scale numeric variables.
- Drop id and date variables.
- Separate strings into tokens by replacing non-alphanumeric characters with spaces and splitting on spaces.


In [4]:
# Load the weather table from ../data/weather_clean.parquet, print its
# (rows, columns) shape, and preview five randomly sampled rows.
weather = pl.read_parquet("../data/weather_clean.parquet")
print(weather.shape)
weather.sample(5)

(635664, 25)


x,y,station_name,climate_identifier,province_code,local_year,local_month,local_day,local_hour,temp,temp_flag,dew_point_temp,dew_point_temp_flag,humidex,precip_amount,precip_amount_flag,relative_humidity,relative_humidity_flag,station_pressure,station_pressure_flag,wind_chill,wind_direction,wind_direction_flag,wind_speed,wind_speed_flag
f64,f64,str,str,str,str,str,str,str,f64,str,f64,str,f64,f64,str,f64,str,f64,str,f64,f64,str,f64,str
-111.213333,56.651111,"""FORT MCMURRAY …","""3062696""","""AB""","""2019""","""8""","""23""","""0""",7.9,"""missing""",5.4,"""missing""",,0.0,"""missing""",84.0,"""missing""",96.98,"""missing""",,20.0,"""missing""",11.0,"""missing"""
-112.7675,49.695,"""LETHBRIDGE CDA…","""3033890""","""AB""","""2010""","""11""","""3""","""16""",14.2,"""missing""",-8.0,"""missing""",,,"""missing""",21.0,"""missing""",92.64,"""missing""",,15.0,"""missing""",7.0,"""missing"""
-112.7675,49.695,"""LETHBRIDGE CDA…","""3033890""","""AB""","""2013""","""12""","""9""","""23""",-6.5,"""missing""",-12.3,"""missing""",,,"""missing""",63.0,"""missing""",90.56,"""missing""",-15.0,27.0,"""missing""",33.0,"""missing"""
-111.213333,56.651111,"""FORT MCMURRAY …","""3062696""","""AB""","""2013""","""9""","""25""","""3""",2.6,"""missing""",0.2,"""missing""",,,"""missing""",84.0,"""missing""",96.64,"""missing""",,24.0,"""missing""",11.0,"""missing"""
-114.6825,51.778056,"""SUNDRE A""","""3026KNQ""","""AB""","""2020""","""4""","""21""","""5""",-1.4,"""missing""",-4.4,"""missing""",,0.0,"""missing""",80.0,"""missing""",88.23,"""missing""",-4.0,23.0,"""missing""",6.0,"""missing"""


In [5]:
# Load the energy table from ../data/energy_clean.parquet, print its
# (rows, columns) shape, and preview five randomly sampled rows.
energy = pl.read_parquet("../data/energy_clean.parquet")
print(energy.shape)
energy.sample(5)

(111743, 10)


local_year,local_month,local_day,local_hour,total_energy_no_imports,total_imports,total_exports,actual_pool_price,actual_ail,day_ahead_pool_price
str,str,str,str,f64,f64,f64,f64,i64,f64
"""2014""","""9""","""27""","""17""",6671.968413,275.0,0.0,47.48,8835,44.37
"""2014""","""7""","""19""","""6""",6112.240516,0.0,0.0,12.44,8097,13.92
"""2014""","""4""","""25""","""8""",6912.162844,350.0,0.0,55.3,8987,52.99
"""2015""","""11""","""13""","""21""",7441.938827,0.0,104.0,17.53,9477,17.13
"""2014""","""2""","""10""","""13""",7645.64203,650.0,0.0,90.48,10449,91.16


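Standardize numeric variables (subtract the mean, divide by the standard deviation), then lowercase strings and replace non-alphanumeric characters with spaces.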


In [6]:
def scale_numeric(df):
    """Standardize every Float64/Int64 column of ``df``.

    Each such column is replaced, under its original name, by
    ``(column - mean) / std``. Columns of any other dtype are left as-is.

    Args:
        df: polars DataFrame to transform.

    Returns:
        The DataFrame with its numeric columns standardized.
    """
    numeric_dtypes = (pl.Float64, pl.Int64)
    # dtypes are read once up front; a column is only ever rewritten when it
    # is reached, so this matches checking each dtype lazily.
    for name, dtype in zip(df.columns, df.dtypes):
        if dtype not in numeric_dtypes:
            continue
        column = pl.col(name)
        standardized = (column - column.mean()) / column.std()
        df = df.with_columns(standardized.alias(name))
    return df


# Standardize the numeric columns of both tables.
weather = scale_numeric(weather)
energy = scale_numeric(energy)

In [7]:
def make_lower_remove_special_chars(df):
    """Lowercase all Utf8 columns and replace special characters with spaces.

    Every character that is not an ASCII letter or digit is replaced by a
    single space, so multi-word values stay separable on spaces.

    Args:
        df: polars DataFrame to transform.

    Returns:
        The DataFrame with its string columns normalized.
    """
    string_columns = pl.col(pl.Utf8)
    lowered = string_columns.str.to_lowercase()
    cleaned = lowered.str.replace_all("[^a-zA-Z0-9]", " ")
    return df.with_columns(cleaned)


# Normalize the string columns of both tables.
weather = make_lower_remove_special_chars(weather)
energy = make_lower_remove_special_chars(energy)

In [8]:
def get_unique_utf8_values(df):
    """Return the sorted unique values found across all Utf8 columns of ``df``.

    Args:
        df: polars DataFrame to scan.

    Returns:
        1-D numpy array with the de-duplicated values of every string column
        pooled together.
    """
    string_column_names = df.select(pl.col(pl.Utf8)).columns
    # Seed with an empty array so a frame without string columns still yields
    # an (empty) result.
    pieces = [np.array([])]
    pieces.extend(df[name].unique().to_numpy() for name in string_column_names)
    return np.unique(np.concatenate(pieces))


# Unique string values of each table; the last line displays the energy set.
weather_val_tokens = get_unique_utf8_values(weather)
energy_val_tokens = get_unique_utf8_values(energy)
energy_val_tokens

array(['0', '1', '10', '11', '12', '13', '14', '15', '16', '17', '18',
       '19', '2', '20', '2010', '2011', '2012', '2013', '2014', '2015',
       '2016', '2017', '2018', '2019', '2020', '2021', '2022', '21', '22',
       '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '4',
       '5', '6', '7', '8', '9'], dtype=object)

In [9]:
def get_col_tokens(df):
    """Return the sorted unique word pieces that make up ``df``'s column names.

    Each column name is split on every character that is not an ASCII letter
    or digit (e.g. ``"local_year"`` -> ``"local"``, ``"year"``).

    Args:
        df: object exposing a ``columns`` list of column-name strings.

    Returns:
        1-D numpy array of the de-duplicated pieces.
    """
    pieces = [
        piece
        for column_name in df.columns
        for piece in re.split(r"[^a-zA-Z0-9]", column_name)
    ]
    return np.unique(np.array(pieces))


# Column-name word pieces of each table; the last line displays the energy set.
weather_col_tokens = get_col_tokens(weather)
energy_col_tokens = get_col_tokens(energy)
energy_col_tokens

array(['actual', 'ahead', 'ail', 'day', 'energy', 'exports', 'hour',
       'imports', 'local', 'month', 'no', 'pool', 'price', 'total',
       'year'], dtype='<U7')

In [10]:
# Structural markers used when rows are serialized into token sequences, plus
# the "missing" placeholder used for absent values.
special_tokens = np.array(
    [
        "missing",
        "<batch-start>",
        "<batch-end>",
        "<pad>",
        "<unk>",
        ":",
        ",",
        "<row-start>",
        "<row-end>",
    ]
)
# Full vocabulary: union of value tokens, column-name tokens and the special
# tokens of both tables. np.unique de-duplicates and sorts the result.
tokens = np.unique(
    np.concatenate(
        (
            weather_val_tokens,
            energy_val_tokens,
            weather_col_tokens,
            energy_col_tokens,
            special_tokens,
        )
    )
)
tokens

array([',', '0', '1', '10', '11', '12', '13', '14', '15', '16', '17',
       '18', '19', '2', '20', '2010', '2011', '2012', '2013', '2014',
       '2015', '2016', '2017', '2018', '2019', '2020', '2021', '2022',
       '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30',
       '3012206', '3026knq', '3031094', '3033890', '3035208', '3062696',
       '31', '4', '5', '6', '7', '8', '9', ':', '<batch-end>',
       '<batch-start>', '<pad>', '<row-end>', '<row-start>', '<unk>',
       'ab', 'actual', 'ahead', 'ail', 'amount', 'calgary int l cs',
       'chill', 'climate', 'code', 'day', 'dew', 'direction',
       'edmonton international cs', 'energy', 'exports', 'flag',
       'fort mcmurray cs', 'hour', 'humidex', 'humidity', 'identifier',
       'imports', 'lethbridge cda', 'local', 'm', 'missing', 'month',
       'name', 'no', 'pincher creek climate', 'point', 'pool', 'precip',
       'pressure', 'price', 'province', 'relative', 'speed', 'station',
       'sundre a', 'temp', '

In [11]:
@dataclass
class StringNumeric:
    """One sequence element: either a string token or a numeric value.

    ``is_numeric`` is always derived from the type of ``value`` (anything that
    is not a ``str`` counts as numeric). Numeric elements get
    ``embedding_idx`` 0 right away; string elements keep ``None`` until
    ``gen_embed_idx`` is called.
    """

    value: Union[str, float]
    is_numeric: bool = field(default=None, repr=True)
    embedding_idx: int = field(default=None, repr=True)

    def __post_init__(self):
        """Classify ``value`` and assign index 0 to numeric elements."""
        self.is_numeric = not isinstance(self.value, str)
        if self.is_numeric:
            self.embedding_idx = 0

    def gen_embed_idx(self, tokens: np.array):
        """Set ``embedding_idx`` for a string element from the vocabulary.

        The index is the 1-based position of ``value`` in ``tokens`` (0 is
        taken by numeric elements). A value that is not in ``tokens`` gets the
        index of ``"<unk>"``. Numeric elements are left unchanged.

        Args:
            tokens: 1-D array of vocabulary strings.

        Raises:
            IndexError: if neither ``value`` nor ``"<unk>"`` is in ``tokens``.
        """
        if self.is_numeric:
            return

        def one_based_position(token):
            # [0][0] raises IndexError when there is no match at all.
            return np.where(tokens == token)[0][0] + 1

        try:
            index = one_based_position(self.value)
        except IndexError:
            index = one_based_position("<unk>")
        self.embedding_idx = index


# Quick demonstration of StringNumeric: a string element starts without an
# embedding index, a float gets index 0 immediately, and gen_embed_idx fills
# in the index of a string element from the vocabulary.
x = StringNumeric(value="climate")
# xx = StringNumeric(value="climate", tokens=tokens)
print(x)
y = StringNumeric(value=1.0)
print(y)
z = StringNumeric(value="SomeRandomString")
print(z)
x.gen_embed_idx(tokens)
print(x)
# print(StringNumeric(value=1.0, all_tokens=tokens))

StringNumeric(value='climate', is_numeric=False, embedding_idx=None)
StringNumeric(value=1.0, is_numeric=True, embedding_idx=0)
StringNumeric(value='SomeRandomString', is_numeric=False, embedding_idx=None)
StringNumeric(value='climate', is_numeric=False, embedding_idx=67)


In [12]:
# Display the vocabulary array again for reference.
tokens

array([',', '0', '1', '10', '11', '12', '13', '14', '15', '16', '17',
       '18', '19', '2', '20', '2010', '2011', '2012', '2013', '2014',
       '2015', '2016', '2017', '2018', '2019', '2020', '2021', '2022',
       '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30',
       '3012206', '3026knq', '3031094', '3033890', '3035208', '3062696',
       '31', '4', '5', '6', '7', '8', '9', ':', '<batch-end>',
       '<batch-start>', '<pad>', '<row-end>', '<row-start>', '<unk>',
       'ab', 'actual', 'ahead', 'ail', 'amount', 'calgary int l cs',
       'chill', 'climate', 'code', 'day', 'dew', 'direction',
       'edmonton international cs', 'energy', 'exports', 'flag',
       'fort mcmurray cs', 'hour', 'humidex', 'humidity', 'identifier',
       'imports', 'lethbridge cda', 'local', 'm', 'missing', 'month',
       'name', 'no', 'pincher creek climate', 'point', 'pool', 'precip',
       'pressure', 'price', 'province', 'relative', 'speed', 'station',
       'sundre a', 'temp', '

In [13]:
class TabularDataset(Dataset):
    """Serializes groups of DataFrame rows into fixed-length token sequences.

    Each item is a list of ``StringNumeric`` elements of length
    ``max_seq_length``, laid out as::

        <batch-start> row... <batch-end> <pad> <pad> ...

    where every row is rendered as
    ``<row-start> col name pieces : value , ... <row-end>``.

    Args:
        df: table to serialize.
        vocab: 1-D array of vocabulary strings used to look up token indices.
        shuffle_cols: if True, the column order is shuffled independently for
            every row.
        n_rows: number of consecutive rows per item. Must be set to a positive
            int before the dataset is indexed or its length is taken.
        max_seq_length: exact length of every returned sequence.
    """

    def __init__(
        self,
        df: pl.DataFrame,
        vocab,
        shuffle_cols=False,
        n_rows=None,
        max_seq_length=512,
    ) -> None:
        self.df = df
        self.vocab = vocab
        self.shuffle_cols = shuffle_cols
        self.n_rows = n_rows
        self.max_seq_length = max_seq_length

    def __len__(self):
        """Returns the number of sequences in the dataset."""
        return self.df.shape[0] // self.n_rows

    def __getitem__(self, idx):
        """Returns the padded/truncated token sequence for item ``idx``."""
        start = StringNumeric("<batch-start>")
        start.gen_embed_idx(self.vocab)
        end = StringNumeric("<batch-end>")
        end.gen_embed_idx(self.vocab)
        # Frame first, pad/truncate second. Padding before adding the markers
        # made every item max_seq_length + 2 long and, when truncating,
        # produced two <batch-end> tokens.
        return self.padder([start] + self.batch(idx) + [end])

    def batch(self, idx):
        """Returns the tokens of rows ``idx`` .. ``idx + n_rows - 1``.

        NOTE(review): items start at row ``idx`` (not ``idx * n_rows``), so
        consecutive items overlap while ``__len__`` counts
        ``shape[0] // n_rows`` items -- confirm which of the two is intended.
        """
        batch = []
        for i in range(idx, idx + self.n_rows):
            batch.extend(self.splitter(self.df[i]))
        return batch

    def padder(self, batch: List[StringNumeric]):
        """Pads or truncates ``batch`` to exactly ``max_seq_length`` elements.

        Short sequences are extended in place with ``<pad>`` elements. Long
        sequences are cut to ``max_seq_length - 1`` elements and closed with a
        fresh ``<batch-end>``; a warning is issued in that case.

        Returns:
            The sequence with length ``max_seq_length``.
        """
        import warnings

        diff = self.max_seq_length - len(batch)
        if diff > 0:
            pad = StringNumeric("<pad>")
            pad.gen_embed_idx(self.vocab)
            batch.extend([pad] * diff)
        elif diff < 0:
            batch = batch[: self.max_seq_length - 1]
            new_end = StringNumeric("<batch-end>")
            new_end.gen_embed_idx(self.vocab)
            batch.append(new_end)
            # The bare `Warning(...)` expression only built an exception
            # object and discarded it; emit a real warning instead.
            warnings.warn("Batch too long, truncating", stacklevel=2)
        return batch

    def splitter(self, row: pl.DataFrame) -> List[StringNumeric]:
        """Renders a single-row frame as a list of ``StringNumeric`` elements.

        Column names are split on ``"_"``, string values on ``" "``; numbers
        are kept as numeric elements and ``None`` becomes ``"missing"``.

        Raises:
            ValueError: if a cell is neither a number, a string, nor None.
        """
        vals = ["<row-start>"]
        cols = row.columns
        if self.shuffle_cols:
            np.random.shuffle(cols)

        for col in cols:
            value = row[col][0]
            vals.extend(col.split("_"))
            vals.append(":")
            if isinstance(value, Number):
                vals.append(value)
            elif value is None:
                # Nones are only for numeric columns, others are "None"
                vals.append("missing")
            elif isinstance(value, str):
                vals.extend(value.split(" "))
            else:
                raise ValueError("Unknown type")
            vals.append(",")
        vals.append("<row-end>")

        elements = [StringNumeric(value=val) for val in vals]
        for element in elements:
            element.gen_embed_idx(self.vocab)
        return elements


# Build one dataset per table over the shared vocabulary, two rows per item,
# and show the first ten elements of the first weather item.
weather_ds = TabularDataset(weather, tokens, shuffle_cols=False, n_rows=2)
energy_ds = TabularDataset(energy, tokens, shuffle_cols=False, n_rows=2)
print(weather_ds[0][:10])

[StringNumeric(value='<batch-start>', is_numeric=False, embedding_idx=55), StringNumeric(value='<row-start>', is_numeric=False, embedding_idx=58), StringNumeric(value='x', is_numeric=False, embedding_idx=103), StringNumeric(value=':', is_numeric=False, embedding_idx=53), StringNumeric(value=-0.551099305737714, is_numeric=True, embedding_idx=0), StringNumeric(value=',', is_numeric=False, embedding_idx=1), StringNumeric(value='y', is_numeric=False, embedding_idx=104), StringNumeric(value=':', is_numeric=False, embedding_idx=53), StringNumeric(value=-0.37817406811183396, is_numeric=True, embedding_idx=0), StringNumeric(value=',', is_numeric=False, embedding_idx=1)]


In [14]:
@dataclass
class Config:
    """Hyperparameters for the Transformer modules.

    ``kvq_dim`` is derived from the instance's own ``embed_dim`` and
    ``n_heads`` unless it is passed explicitly, so overriding either of those
    keeps the per-head width consistent (previously it was frozen at
    256 // 8 when the class was defined).
    """

    embed_dim: int = 256  # feature width of embeddings and layer outputs
    n_heads: int = 8  # number of attention heads
    # Per-head feature width; None means embed_dim // n_heads (__post_init__).
    kvq_dim: Optional[int] = None
    ff_dim: int = 512  # base width used by FeedForward
    p_drop: float = 0.2  # dropout rate
    encoder_vocab_size: int = weather_ds.vocab.shape[0]
    decoder_vocab_size: int = energy_ds.vocab.shape[0]
    # The fields below are read by the encoder/decoder/Transformer modules
    # (config.n_layers, config.vocab_size, config.pad_idx).
    # Number of stacked encoder and decoder layers; the default is a starting
    # value to tune, nothing else in the notebook fixes it.
    n_layers: int = 6
    # Embedding indices are 1-based positions in the vocabulary with 0
    # reserved for numeric values, hence one slot more than the vocabulary.
    vocab_size: int = max(weather_ds.vocab.shape[0], energy_ds.vocab.shape[0]) + 1
    # Embedding index of "<pad>", using the same 1-based convention.
    pad_idx: int = int(np.flatnonzero(weather_ds.vocab == "<pad>")[0]) + 1

    def __post_init__(self):
        """Fill in ``kvq_dim`` from ``embed_dim`` and ``n_heads`` if unset."""
        if self.kvq_dim is None:
            self.kvq_dim = self.embed_dim // self.n_heads


# Default hyperparameter set.
config = Config()

In [15]:
# Scratch check: indexing with jnp.newaxis adds a leading axis, so the 1-D
# range becomes a 2-D array.
# jnp.arange(10) * jnp.arange(10).T#[:, None]
# jnp.einsum("i,j->ij", jnp.arange(10), jnp.arange(10))
jnp.arange(10)[jnp.newaxis, :].ndim  # [:, None]



2

In [16]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to the input.

    For position ``p`` and even feature index ``2i`` the encoding is
    ``sin(p / 10000^(2i / embed_dim))``; the following odd index holds the
    matching cosine. The table has no learned parameters.
    """

    config: Config

    @nn.compact
    def __call__(self, x: jnp.array) -> jnp.array:
        """Returns ``x`` plus the positional table, as float32.

        Args:
            x: rank-3 array; axis 1 is treated as the sequence axis and the
                last axis must broadcast against ``config.embed_dim``.
        """
        assert x.ndim == 3, "Input must have rank 3"
        embed_dim = self.config.embed_dim
        seq_len = x.shape[1]

        position = jnp.arange(seq_len)[:, jnp.newaxis]  # (seq_len, 1)
        div_term = jnp.exp(
            jnp.arange(0, embed_dim, 2) * (-jnp.log(10000.0) / embed_dim)
        )  # (ceil(embed_dim / 2),)
        # Outer product position x frequency -> (seq_len, ceil(embed_dim / 2)).
        radians = position * div_term[jnp.newaxis, :]

        # jnp.zeros takes the shape as a single tuple argument.
        pe = jnp.zeros((seq_len, embed_dim))
        pe = pe.at[:, 0::2].set(jnp.sin(radians))
        # For an odd embed_dim there is one cosine column fewer than sines.
        pe = pe.at[:, 1::2].set(jnp.cos(radians[:, : embed_dim // 2]))
        # The same table is broadcast over the batch axis.
        return (x + pe[jnp.newaxis, :, :]).astype(jnp.float32)

In [17]:
class MultiHeadAttention(nn.Module):
    """Multi Headed Dot Product Attention"""

    config: Config

    @nn.compact
    def __call__(
        self,
        q: jnp.array,
        k: jnp.array,
        v: jnp.array,
        mask: jnp.array = None,
        dropout: float = 0.0,
    ) -> Tuple[jnp.array, jnp.array]:
        """Applies scaled dot-product attention over ``config.n_heads`` heads.

        Args:
            q: queries, rank 3 with the sequence on axis 1.
            k: keys, rank 3; may have a different sequence length than ``q``.
            v: values, rank 3 with the same sequence length as ``k``.
            mask: optional boolean array broadcastable to
                ``(batch, heads, q_len, k_len)``; False entries are masked out.
            dropout: accepted for interface compatibility; currently unused.

        Returns:
            ``(out, attention)``: the projected output with the same shape as
            ``q``, and the attention weights of shape
            ``(batch, heads, q_len, k_len)``.
        """
        config = self.config

        assert q.ndim == k.ndim == v.ndim == 3, "Input must have rank 3"
        assert q.shape[0] == k.shape[0] == v.shape[0], "Batch size must be equal"
        assert (
            q.shape[2] == k.shape[2] == v.shape[2]
        ), "Embedding dimension must be equal"
        assert k.shape[1] == v.shape[1], "Key and value lengths must be equal"

        _, q_len, embed_dim = q.shape
        # Keys/values come from the encoder memory in cross-attention, so
        # their length must not be assumed to equal the query length.
        kv_len = k.shape[1]
        assert (
            embed_dim % config.n_heads == 0
        ), "Embedding dimension must be divisible by number of heads"
        assert (
            config.n_heads * config.kvq_dim == config.embed_dim
        ), "n_heads * kvq_dim must equal embed_dim"

        q = nn.Dense(config.embed_dim, name="DenseQ")(q)
        k = nn.Dense(config.embed_dim, name="DenseK")(k)
        v = nn.Dense(config.embed_dim, name="DenseV")(v)

        q = q.reshape(-1, q_len, config.n_heads, config.kvq_dim)
        k = k.reshape(-1, kv_len, config.n_heads, config.kvq_dim)
        v = v.reshape(-1, kv_len, config.n_heads, config.kvq_dim)

        attention = jnp.einsum("...qhd,...khd->...hqk", q, k) / jnp.sqrt(config.kvq_dim)

        if mask is not None:
            # A large finite negative keeps softmax defined even for rows
            # that are masked out completely (-inf would yield NaN there).
            attention = jnp.where(mask, attention, jnp.finfo(attention.dtype).min)

        attention = nn.softmax(attention, axis=-1)
        values = jnp.einsum("...hqk,...khd->...qhd", attention, v)
        values = values.reshape(-1, q_len, config.n_heads * config.kvq_dim)
        out = nn.Dense(embed_dim)(values)

        return out, attention

In [18]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network with dropout.

    Expands the features to ``config.ff_dim * 4``, applies ReLU, and projects
    back to the width of the input so the result can be added to a residual
    copy of that input.
    """

    config: Config

    @nn.compact
    def __call__(self, x: jnp.array, deterministic: bool) -> jnp.array:
        """Applies the two-layer network to the last axis of ``x``.

        Args:
            x: input features.
            deterministic: if True, dropout is disabled.

        Returns:
            Array with the same shape as ``x``.
        """
        config = self.config
        # Project back to the incoming width; projecting to ff_dim made the
        # residual add in the encoder layer fail whenever ff_dim != embed_dim.
        model_dim = x.shape[-1]
        x = nn.Dense(config.ff_dim * 4, name="FFDense1")(x)
        x = nn.relu(x)
        x = nn.Dropout(config.p_drop)(x, deterministic=deterministic)
        x = nn.Dense(model_dim, name="FFDense2")(x)
        x = nn.Dropout(config.p_drop)(x, deterministic=deterministic)

        return x

In [19]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sub-layer is followed by a residual connection and a LayerNorm
    (post-norm). Calling the layer returns a pair: the layer output and the
    self-attention weights produced by MultiHeadAttention.
    """

    config: Config

    @nn.compact
    def __call__(self, x: jnp.array, mask: jnp.array, deterministic: bool) -> jnp.array:
        cfg = self.config

        # Self-attention sub-layer: queries, keys and values are all ``x``.
        attended, attention = MultiHeadAttention(cfg)(x, x, x, mask)
        attended = nn.Dropout(cfg.p_drop)(attended, deterministic=deterministic)
        normed = nn.LayerNorm(name="LayerNorm1")(attended + x)

        # Feed-forward sub-layer.
        transformed = FeedForward(cfg, name="FeedForward")(
            normed, deterministic=deterministic
        )
        out = nn.LayerNorm(name="LayerNorm2")(transformed + normed)

        return out, attention

In [20]:
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: masked self-attention then encoder-decoder attention.

    Each sub-layer is followed by dropout, a residual connection and a
    LayerNorm.

    NOTE(review): unlike TransformerEncoderLayer, no feed-forward sub-layer is
    applied here -- confirm whether that is intentional.
    """

    config: Config

    @nn.compact
    def __call__(
        self,
        x: jnp.array,
        memory: jnp.array,
        decoder_mask: jnp.array,
        encoder_decoder_mask: jnp.array,
        deterministic: bool,
    ) -> Tuple[jnp.array, jnp.array, jnp.array]:
        """Runs the layer.

        Args:
            x: decoder-side features.
            memory: encoder output, used as keys and values of the second
                attention.
            decoder_mask: mask for the self-attention.
            encoder_decoder_mask: mask for the attention over ``memory``.
            deterministic: if True, dropout is disabled.

        Returns:
            ``(x, self_attention, src_attention)``: the layer output, the
            self-attention weights and the encoder-decoder attention weights.
        """
        config = self.config
        res = x
        x, self_attention = MultiHeadAttention(config)(x, x, x, decoder_mask)
        x = nn.Dropout(config.p_drop)(x, deterministic=deterministic)
        x = nn.LayerNorm(name="LayerNorm1")(x + res)
        res = x
        # Keep both attention maps: the caller unpacks three values, and the
        # self-attention weights used to be overwritten here.
        x, src_attention = MultiHeadAttention(config)(
            x, memory, memory, encoder_decoder_mask
        )
        x = nn.Dropout(config.p_drop)(x, deterministic=deterministic)
        x = nn.LayerNorm(name="LayerNorm2")(x + res)
        return x, self_attention, src_attention

In [21]:
# Scratch check of nn.Embed: look up ten ids in a table of 10 rows x 5
# features. Equal ids give equal rows. The last id (10) has no row in a table
# with num_embeddings=10; in the recorded output that row is all NaN,
# presumably for that reason.
x = jnp.array([1, 1, 3, 3, 5, 6, 7, 8, 9, 10])

emb = nn.Embed(num_embeddings=10, features=5)
emb_variables = emb.init(random.PRNGKey(0), x)
emb_output = emb.apply(emb_variables, x)
emb_output

Array([[-0.440659  ,  0.23434746,  0.44432253,  0.35004553, -1.0721016 ],
       [-0.440659  ,  0.23434746,  0.44432253,  0.35004553, -1.0721016 ],
       [-0.5751345 , -0.31515005, -0.7112598 , -0.22740227,  0.52951103],
       [-0.5751345 , -0.31515005, -0.7112598 , -0.22740227,  0.52951103],
       [-0.1656743 ,  0.0334887 ,  0.7145505 , -0.69273764, -0.7260953 ],
       [-0.12573542, -0.35354394,  0.10587429,  0.05312492,  0.4437487 ],
       [-0.04921824, -0.10327773, -0.13198015, -0.4364049 , -0.38484678],
       [ 0.31554   ,  0.09454814,  0.4258944 , -0.27779824,  0.07495691],
       [ 0.71339667, -0.04502202, -0.35073858,  0.59638804,  0.05645175],
       [        nan,         nan,         nan,         nan,         nan]],      dtype=float32)

In [22]:
# Display the embedding lookup result again.
emb_output

Array([[-0.440659  ,  0.23434746,  0.44432253,  0.35004553, -1.0721016 ],
       [-0.440659  ,  0.23434746,  0.44432253,  0.35004553, -1.0721016 ],
       [-0.5751345 , -0.31515005, -0.7112598 , -0.22740227,  0.52951103],
       [-0.5751345 , -0.31515005, -0.7112598 , -0.22740227,  0.52951103],
       [-0.1656743 ,  0.0334887 ,  0.7145505 , -0.69273764, -0.7260953 ],
       [-0.12573542, -0.35354394,  0.10587429,  0.05312492,  0.4437487 ],
       [-0.04921824, -0.10327773, -0.13198015, -0.4364049 , -0.38484678],
       [ 0.31554   ,  0.09454814,  0.4258944 , -0.27779824,  0.07495691],
       [ 0.71339667, -0.04502202, -0.35073858,  0.59638804,  0.05645175],
       [        nan,         nan,         nan,         nan,         nan]],      dtype=float32)

In [23]:
# Scratch check of functional updates: .at[2].set(100) returns an array with
# element 2 replaced, which is rebound to y here.
y = jnp.arange(10)
print(y)
y = y.at[2].set(100)
print(y)

[0 1 2 3 4 5 6 7 8 9]
[  0   1 100   3   4   5   6   7   8   9]


In [30]:
class FloatEmbedding(nn.Module):
    """Embeds a sequence of StringNumeric elements.

    String elements are looked up in a learned embedding table by their
    ``embedding_idx``. Numeric elements become a vector that is zero
    everywhere except position 0, which carries the numeric value itself.

    Attributes:
        config: hyperparameters; a default ``Config()`` is used when omitted.

    NOTE(review): the result is rank 2, ``(len(x), embed_dim)``, while
    PositionalEncoding asserts a rank-3 input -- confirm how the batch axis is
    meant to be added.
    """

    config: Optional[Config] = None

    @nn.compact
    def __call__(self, x: list[StringNumeric]) -> jnp.array:
        """Returns an array of shape ``(len(x), config.embed_dim)``."""
        config = self.config if self.config is not None else Config()

        # Embedding indices are 1-based vocabulary positions with 0 reserved
        # for numeric values, so the table needs one row more than the
        # vocabulary when no explicit vocab_size is configured.
        vocab_size = getattr(config, "vocab_size", None)
        if vocab_size is None:
            vocab_size = (
                max(config.encoder_vocab_size, config.decoder_vocab_size) + 1
            )

        embed = nn.Embed(
            vocab_size,
            config.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=config.embed_dim**-0.5),
            name="FloatEmbedding",
        )

        # Gather the per-element fields once and do a single vectorized lookup
        # instead of one functional array update per element.
        is_numeric = jnp.array([sn.is_numeric for sn in x], dtype=bool)
        token_ids = jnp.array([sn.embedding_idx for sn in x], dtype=jnp.int32)
        numeric_values = jnp.array(
            [sn.value if sn.is_numeric else 0.0 for sn in x], dtype=jnp.float32
        )

        token_rows = embed(token_ids)
        numeric_rows = (
            jnp.zeros((len(x), config.embed_dim)).at[:, 0].set(numeric_values)
        )
        # Numeric elements carry index 0; their looked-up row is discarded.
        return jnp.where(is_numeric[:, jnp.newaxis], numeric_rows, token_rows)

In [31]:
class TransformerEncoder(nn.Module):
    """Transformer Encoder.

    Embeds the input, adds positional encodings, applies dropout, runs
    ``config.n_layers`` encoder layers and finishes with a LayerNorm. The
    config passed in must therefore provide ``n_layers``.
    """

    config: Config

    @nn.compact
    def __call__(
        self,
        x: jnp.array,
        mask: jnp.array,
        deterministic: bool,
        return_attention: bool = False,
    ) -> Union[jnp.array, Tuple[jnp.array, List[jnp.array]]]:
        """Encodes ``x``.

        Args:
            x: input handed to FloatEmbedding.
            mask: attention mask forwarded to every layer.
            deterministic: if True, dropout is disabled.
            return_attention: if True, also return the per-layer attention
                weights.

        Returns:
            The encoded features, or ``(features, attention_list)`` when
            ``return_attention`` is True.
        """
        cfg = self.config
        hidden = FloatEmbedding(cfg)(x)
        hidden = PositionalEncoding(cfg)(hidden)
        hidden = nn.Dropout(cfg.p_drop)(hidden, deterministic=deterministic)

        layer_attentions = []
        for layer_idx in range(cfg.n_layers):
            layer = TransformerEncoderLayer(cfg, name=f"EncoderLayer_{layer_idx}")
            hidden, layer_attention = layer(hidden, mask, deterministic)
            layer_attentions.append(layer_attention)

        hidden = nn.LayerNorm(name="LayerNorm")(hidden)
        if not return_attention:
            return hidden
        return hidden, layer_attentions

In [32]:
class TransformerDecoder(nn.Module):
    """Transformer Decoder.

    Embeds the target, adds positional encodings, applies dropout, runs
    ``config.n_layers`` decoder layers, then a LayerNorm and a Dense
    projection to ``config.vocab_size`` logits. The config passed in must
    therefore provide ``n_layers`` and ``vocab_size``.
    """

    config: Config

    @nn.compact
    def __call__(
        self,
        x: jnp.array,
        memory: jnp.array,
        decoder_mask: jnp.array,
        encoder_decoder_mask: jnp.array,
        deterministic: bool,
        return_attention: bool = False,
    ) -> Union[jnp.array, Tuple[jnp.array, List[jnp.array]]]:
        """Decodes ``x`` against the encoder ``memory``.

        Three values are unpacked from every decoder layer: the layer output,
        its self-attention weights and its encoder-decoder attention weights.

        Returns:
            The logits, or ``(logits, self_attention_list, src_attention_list)``
            when ``return_attention`` is True.
        """
        cfg = self.config
        hidden = FloatEmbedding(cfg)(x)
        hidden = PositionalEncoding(cfg)(hidden)
        hidden = nn.Dropout(cfg.p_drop)(hidden, deterministic=deterministic)

        self_attentions, src_attentions = [], []
        for layer_idx in range(cfg.n_layers):
            layer = TransformerDecoderLayer(cfg, name=f"DecoderLayer_{layer_idx}")
            hidden, self_attn, src_attn = layer(
                hidden,
                memory,
                decoder_mask,
                encoder_decoder_mask,
                deterministic,
            )
            self_attentions.append(self_attn)
            src_attentions.append(src_attn)

        hidden = nn.LayerNorm(name="LayerNorm")(hidden)
        logits = nn.Dense(cfg.vocab_size, name="Dense")(hidden)
        if not return_attention:
            return logits
        return logits, self_attentions, src_attentions

In [33]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    ``encode`` and ``decode`` are exposed separately so the encoder output can
    be reused; ``__call__`` chains the two. Padding masks are built by
    comparing the inputs against ``config.pad_idx``, so the config passed in
    must provide ``pad_idx``.

    NOTE(review): the masks treat ``src``/``trg`` as integer id arrays, while
    the encoder and decoder hand the same inputs to FloatEmbedding, which
    iterates over StringNumeric elements -- confirm the intended input format.
    """

    config: Config

    def setup(self):
        self.Encoder = TransformerEncoder(self.config)
        self.Decoder = TransformerDecoder(self.config)

    def __call__(
        self, src: jnp.array, tgt: jnp.array, train: bool = False
    ) -> jnp.array:
        """Encodes ``src`` and decodes ``tgt`` against it; returns the logits."""
        memory = self.encode(src, train=train)
        return self.decode(tgt, src, memory, train=train)

    def encode(
        self, src: jnp.array, train: bool = False, return_attn: bool = False
    ) -> Union[jnp.array, Tuple[jnp.array, List[jnp.array]]]:
        """Encode sources with Transformer Encoder.

        Args:
          src: sources of shape [batch_size, src_length].
          train: To train, it should be set True, otherwise False.
          return_attn: if true, also returns the encoder attention matrixes.

        Returns:
          The encoded sources, or a tuple of the encoded sources and the list
          of attention matrixes per layer when return_attn is True.
        """
        config = self.config

        encoder_mask = nn.make_attention_mask(
            jnp.ones_like(src), src != config.pad_idx, dtype=bool
        )
        # The encoder's keyword is `return_attention`; dropout is disabled
        # (deterministic) whenever we are not training.
        return self.Encoder(
            src, encoder_mask, not train, return_attention=return_attn
        )

    def decode(
        self,
        trg: jnp.array,
        src: jnp.array,  # only for making mask
        memory: jnp.array,
        train: bool = False,
        return_attn: bool = False,
    ) -> Union[jnp.array, Tuple[jnp.array, List[jnp.array], List[jnp.array]]]:
        """Decode targets with Transformer Decoder.

        Args:
          trg: targets of shape [batch_size, trg_length].
          src: sources of shape [batch_size, src_length].
          memory: encoded sources from Transformer Encoder of shape [batch_size, src_length, features].
          train: To train, it should be set True, otherwise False.
          return_attn: if true, returns Self-Attention and Source-Target-Attention matrixes.

        Returns:
          If return_attn is True,
            the logits of shape [batch_size, trg_length, target_vocab_size],
            list of attention matrix in Self-Attention for the number of layers,
            and list of attention matrix in Source-Target-Attention for the number of layers.
          else,
            the logits.
        """
        assert trg.ndim == 2
        config = self.config

        decoder_mask = nn.combine_masks(
            nn.make_attention_mask(
                jnp.ones_like(trg), trg != config.pad_idx, dtype=bool
            ),
            nn.make_causal_mask(trg, dtype=bool),
        )  # [Batch, 1, SeqLen_q, SeqLen_k]
        encoder_decoder_mask = nn.make_attention_mask(
            jnp.ones_like(trg), src != config.pad_idx, dtype=bool
        )  # [Batch, 1, SeqLen_q, SeqLen_k]

        # TransformerDecoder takes the targets first and the memory second.
        return self.Decoder(
            trg,
            memory,
            decoder_mask,
            encoder_decoder_mask,
            not train,
            return_attention=return_attn,
        )