<https://github.com/PolymathicAI/xVal>


In [1]:
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [2]:
import jax.numpy as jnp  # Oddly works in colab to set gpu

arr = jnp.array([1, 2, 3])
arr.devices()

{cuda(id=0)}

In [3]:
import icecream
from icecream import ic

icecream.install()
ic_disable = True
if ic_disable:
    ic.disable()
ic.configureOutput(includeContext=True, contextAbsPath=True)

In [4]:
import os
import ast

from datetime import datetime as dt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import hephaestus as hp
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
from flax.training import train_state
from icecream import ic
from jax import random
from flax import struct
from jax.tree_util import tree_flatten
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange

pd.options.mode.copy_on_write = True



In [5]:
def line2df(line, idx):
    data_rows = []
    line = ast.literal_eval(line)
    for i, time_step in enumerate(line["data"]):
        row = {"time_step": i}
        # Add position data for each planet
        for j, position in enumerate(time_step):
            row[f"planet{j}_x"] = position[0]
            row[f"planet{j}_y"] = position[1]
        data_rows.append(row)

    df = pd.DataFrame(data_rows)
    description = line.pop("description")
    step_size = description.pop("stepsize")
    for k, v in description.items():
        for k_prop, v_prop in v.items():
            df[f"{k}_{k_prop}"] = v_prop
    df["time_step"] = df["time_step"] * step_size
    df.insert(0, "idx", idx)

    return df

In [6]:
files = os.listdir("data")
if "planets.parquet" not in files:
    with open("data/planets.data") as f:
        data = f.read().splitlines()

        dfs = []
        for idx, line in enumerate(tqdm(data)):
            dfs.append(line2df(line, idx))
        print("Concatenating dfs...")
        df = pd.concat(dfs)
    df.to_parquet("data/planets.parquet")
else:
    df = pd.read_parquet("data/planets.parquet")

In [7]:
# Get min, mean, and max number of time steps
df.groupby("idx").count().time_step.agg(["min", "mean", "max"])

min     30.000000
mean    44.511656
max     59.000000
Name: time_step, dtype: float64

In [8]:
class SimpleDS(Dataset):
    def __init__(self, df):
        # Add nan padding to make sure all sequences are the same length
        # use the idx column to group by
        self.max_seq_len = df.groupby("idx").count().time_step.max()

        self.df = df
        self.batch_size = self.max_seq_len

        self.special_tokens = ["[PAD]", "[NUMERIC_MASK]", "[MASK]"]
        self.cat_mask = "[MASK]"
        self.numeric_mask = "[NUMERIC_MASK]"

        self.col_tokens = [col_name for col_name in df.columns if col_name != "idx"]

        self.tokens = self.special_tokens + self.col_tokens

        self.token_dict = {token: i for i, token in enumerate(self.tokens)}
        self.token_decoder_dict = {i: token for i, token in enumerate(self.tokens)}
        self.n_tokens = len(self.tokens)
        self.numeric_indices = jnp.array(
            [self.tokens.index(i) for i in self.col_tokens]
        )

        self.numeric_mask_token = self.tokens.index(self.numeric_mask)

    def __len__(self):
        return df.idx.max() + 1  # probably should be max idx + 1 thanks

    def __getitem__(self, set_idx):
        batch = self.df.loc[
            df.idx == set_idx, [col for col in self.df.columns if col != "idx"]
        ]
        batch = np.array(batch.values)
        # Add padding
        batch_len, n_cols = batch.shape
        pad_len = self.max_seq_len - batch_len
        padding = np.full((pad_len, n_cols), jnp.nan)
        batch = np.concatenate([batch, padding], axis=0)
        return batch

In [9]:
np.full((3, 3), np.nan)

array([[nan, nan, nan],
       [nan, nan, nan],
       [nan, nan, nan]])

In [10]:
train_ds = SimpleDS(df)

In [11]:
time_series_regressor = hp.simple_time_series.SimplePred(train_ds, d_model=64 * 4)

In [12]:
time_series_regressor.d_model

256

In [13]:
train_ds[0]

array([[ 0.        ,  1.56006022, -0.85443699, ...,         nan,
                nan,         nan],
       [ 0.46511628,  1.68985799, -0.5143588 , ...,         nan,
                nan,         nan],
       [ 0.93023256,  1.75358875, -0.15420858, ...,         nan,
                nan,         nan],
       ...,
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan]])

In [14]:
def make_batch(ds: SimpleDS, start: int, length: int):
    data = []
    for i in range(start, length + start):
        data.append(ds[i])

    return jnp.array(data)


batch = make_batch(train_ds, 0, 4)

In [15]:
vars = time_series_regressor.init(random.PRNGKey(0), batch)

x = time_series_regressor.apply(vars, batch)

ic| simple_time_series.py:188 in __call__()
    col_embeddings.shape: (26, 256)
    numeric_inputs.shape: (4, 59, 26)
ic| simple_time_series.py:192 in __call__()
    repeated_numeric_indices.shape: (59, 26)
ic| simple_time_series.py:202 in __call__()
    numeric_inputs.shape: (4, 59, 26)
    numeric_col_embeddings.shape: (59, 26, 256)
ic| simple_time_series.py:206 in __call__()
    "here!!!!!!": 'here!!!!!!'
    numeric_inputs.shape: (4, 59, 26)
    numeric_col_embeddings.shape: (59, 26, 256)
    nan_mask.shape: (4, 59, 26)
    numeric_mat_mull.shape: (4, 59, 26, 256)
ic| simple_time_series.py:220 in __call__()
    f"Nan values in out: {jnp.isnan(out).any()}": 'Nan values in out: False'
ic| simple_time_series.py:225 in __call__()
    out.shape: (4, 59, 26, 272)
ic| simple_time_series.py:226 in __call__()
    f"Nan values in out positional: {jnp.isnan(out).any()}": 'Nan values in out positional: False'
ic| simple_time_series.py:227 in __call__()- 'Starting Attention'
ic| simple_time_ser

In [16]:
1 / 0

ZeroDivisionError: division by zero

In [None]:
x

Array([[[-0.99787635, -1.0830878 , -1.0205737 , ..., -1.1147583 ,
         -1.1147583 , -1.1147583 ],
        [-0.6347571 , -1.0432668 , -0.9717052 , ..., -0.9100214 ,
         -0.9100214 , -0.9100214 ],
        [-0.43598816, -1.0155857 , -0.74097013, ..., -0.92857766,
         -0.92857766, -0.92857766],
        ...,
        [-1.5675948 , -1.5675948 , -1.5675948 , ..., -1.5675948 ,
         -1.5675948 , -1.5675948 ],
        [-1.5146824 , -1.5146824 , -1.5146824 , ..., -1.5146824 ,
         -1.5146824 , -1.5146824 ],
        [-1.324331  , -1.324331  , -1.324331  , ..., -1.324331  ,
         -1.324331  , -1.324331  ]],

       [[-0.99787635, -0.91905284, -0.6940967 , ..., -1.1147583 ,
         -1.1147583 , -1.1147583 ],
        [-0.656036  , -0.6604141 , -0.38698754, ..., -0.9100214 ,
         -0.9100214 , -0.9100214 ],
        [-0.47413844, -0.59235746, -0.11756715, ..., -0.92857766,
         -0.92857766, -0.92857766],
        ...,
        [-1.5675948 , -1.5675948 , -1.5675948 , ..., -

In [None]:
batch.shape

(4, 59, 26)

In [None]:
def calculate_memory_footprint(params):
    """Calculate total memory footprint of JAX model parameters."""
    total_bytes = 0
    # Flatten the parameter tree structure into a list of arrays
    flat_params, _ = tree_flatten(params)
    for param in flat_params:
        # Calculate bytes: number of elements * size of each element
        bytes_per_param = param.size * param.dtype.itemsize
        total_bytes += bytes_per_param
    return total_bytes

In [None]:
def create_causal_mask(tensor: jnp.ndarray):
    """Create a causal mask to mask out future values."""
    mask = jnp.tril(jnp.ones((tensor.shape[0], tensor.shape[1])))
    return mask


def create_padding_mask(tensor: jnp.ndarray):
    """Create a padding mask to mask out padded values."""
    mask = jnp.isnan(tensor)
    return mask


def mask_array(tensor: jnp.array):
    """Create a mask for the tensor"""
    causal_mask = create_causal_mask(tensor)
    padding_mask = create_padding_mask(tensor)
    mask = jnp.logical_or(causal_mask, padding_mask)
    return mask

In [None]:
jnp.array(train_ds[0]).shape

(59, 26)

In [None]:
mts_root_key = random.PRNGKey(44)
mts_main_key, ts_params_key, ts_data_key = random.split(mts_root_key, 3)


def clip_gradients(gradients, max_norm):
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(grad)) for grad in gradients.values()))
    scale = max_norm / (total_norm + 1e-6)
    clipped_gradients = jax.tree_map(
        lambda grad: jnp.where(total_norm > max_norm, grad * scale, grad), gradients
    )
    return clipped_gradients


# @jax.jit
def calculate_loss(params, state, inputs, dataset: SimpleDS):
    out = state.apply_fn(
        {"params": params},
        hp.mask_tensor(inputs, dataset, prng_key=ts_data_key),
    )

    # Create mask for nan inputs
    nan_mask = jnp.isnan(inputs)
    inputs = jnp.where(nan_mask, jnp.zeros_like(inputs), inputs)
    out = jnp.where(nan_mask, jnp.zeros_like(out), out)

    raw_loss = optax.squared_error(out, inputs)
    masked_loss = jnp.where(nan_mask, 0.0, raw_loss)
    loss = masked_loss.sum() / (~nan_mask).sum()

    return loss


@jax.jit
def eval_step(state: train_state.TrainState, batch):
    def loss_fn(params):
        return calculate_loss(params, state, batch, train_ds)

    # (loss, individual_losses), grad = grad_fn(state.params)
    loss, loss_dict = loss_fn(state.params)
    return loss, loss_dict


@jax.jit
def train_step(state: train_state.TrainState, batch):
    def loss_fn(params):
        return calculate_loss(params, state, batch, train_ds)

    grad_fn = jax.value_and_grad(loss_fn)

    # (loss, individual_losses), grad = grad_fn(state.params)
    loss, grad = grad_fn(state.params)
    # grad = replace_nans(grad)
    # grad = clip_gradients(grad, 1.0)
    state = state.apply_gradients(grads=grad)

    return state, loss


def create_train_state(model, prng, batch, lr):
    params = model.init(prng, batch)
    # optimizer = optax.chain(optax.adam(lr))
    optimizer = optax.chain(optax.clip_by_global_norm(0.4), optax.adam(lr))
    # optimizer_state = optimizer.init(params)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params["params"],
        tx=optimizer,
        # tx_state=optimizer_state,
    )


batch_size = 2
# batch = train_ds[0]

state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

ic| simple_time_series.py:188 in __call__()
    col_embeddings.shape: (26, 256)
    numeric_inputs.shape: (4, 59, 26)
ic| simple_time_series.py:192 in __call__()
    repeated_numeric_indices.shape: (59, 26)
ic| simple_time_series.py:202 in __call__()
    numeric_inputs.shape: (4, 59, 26)
    numeric_col_embeddings.shape: (59, 26, 256)
ic| simple_time_series.py:206 in __call__()
    numeric_inputs.shape: (4, 59, 26)
    numeric_col_embeddings.shape: (59, 26, 256)
    nan_mask.shape: (4, 59, 26)
ic| simple_time_series.py:214 in __call__()
    f"Nan values in out: {jnp.isnan(out).any()}": 'Nan values in out: False'
ic| simple_time_series.py:219 in __call__()
    out.shape: (4, 59, 26, 272)
ic| simple_time_series.py:220 in __call__()
    f"Nan values in out positional: {jnp.isnan(out).any()}": 'Nan values in out positional: False'
ic| simple_time_series.py:221 in __call__()- 'Starting Attention'
ic| simple_time_series.py:222 in __call__()
    out.shape: (4, 59, 26, 272)
ic| simple_time_ser

In [None]:
writer_name = "nn_Attention_Same_Params"

writer_time = dt.now().strftime("%Y-%m-%dT%H:%M:%S")

train_summary_writer = SummaryWriter(
    "runs/" + writer_time + "_wow_" + writer_name + "_train"
)


test_set_key = random.PRNGKey(4454)

train_data_loader = DataLoader(train_ds, batch_size=512, shuffle=True)
batch_count = 0
for j in trange(2, desc=f"epochs for {train_summary_writer.log_dir}"):
    # arrs = train_data_loader()
    for i in tqdm(train_data_loader, leave=False, desc="batches"):
        # for i in trange(len(pre_train) // batch_size, leave=False):
        # for i in trange(len(pre_train) // batch_size //10, leave=False):
        # batch = make_batch(train_ds, i[0], 4)

        state, loss = train_step(state, jnp.array(i))
        if jnp.isnan(loss):
            raise ValueError("Nan Value in loss, stopping")
        batch_count += 1

        if batch_count % 1 == 0:
            train_summary_writer.add_scalar(
                "loss/loss", np.array(loss.item()), batch_count
            )
train_summary_writer.close()

epochs for runs/2024-05-20T00:40:31_wow_nn_Attention_Same_Params_train:   0%|          | 0/2 [00:00<?, ?it/s]

batches:   0%|          | 0/245 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import flax.linen as nn

?nn.MultiHeadAttention

[0;31mInit signature:[0m
[0mnn[0m[0;34m.[0m[0mMultiHeadAttention[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mnum_heads[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdtype[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mtype[0m[0;34m[[0m[0mAny[0m[0;34m][0m[0;34m,[0m [0mnumpy[0m[0;34m.[0m[0mdtype[0m[0;34m,[0m [0mjax[0m[0;34m.[0m[0m_src[0m[0;34m.[0m[0mtyping[0m[0;34m.[0m[0mSupportsDType[0m[0;34m,[0m [0mAny[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mparam_dtype[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mtype[0m[0;34m[[0m[0mAny[0m[0;34m][0m[0;34m,[0m [0mnumpy[0m[0;34m.[0m[0mdtype[0m[0;34m,[0m [0mjax[0m[0;34m.[0m[0m_src[0m[0;34m.[0m[0mtyping[0m[0;34m.[0m[0mSupportsDType[0m[0;34m,[0m [0mAny[0m[0;34m][0m [0;34m=[0m [0;34m<[0m[0;32mclass[0m [0;34m'jax.numpy.float32'[

In [None]:
from jax import random

test_key = jax.random.key(12)
att = random.normal(test_key, (2, 10, 6, 16))

In [None]:
test_result = jnp.squeeze(
    state.apply_fn({"params": state.params}, jnp.array([train_ds[0]]))
)
test_result

Array([[1.1705537, 1.1705537, 1.1705537, ..., 1.1705537, 1.1705537,
        1.1705537],
       [1.2930098, 1.293119 , 1.2930987, ..., 1.293119 , 1.293119 ,
        1.293119 ],
       [1.3365014, 1.3365445, 1.336489 , ..., 1.3365445, 1.3365445,
        1.3365445],
       ...,
       [2.2498384, 2.2498384, 2.2498384, ..., 2.2498384, 2.2498384,
        2.2498384],
       [2.252554 , 2.252554 , 2.252554 , ..., 2.252554 , 2.252554 ,
        2.252554 ],
       [2.2437117, 2.2437117, 2.2437117, ..., 2.2437117, 2.2437117,
        2.2437117]], dtype=float32)

In [None]:
jnp.array(train_ds[0]).shape, test_result.shape

((59, 26), (59, 26))

In [None]:
jnp.squeeze(test_result).shape

(59, 26)

In [None]:
df_pred = pd.DataFrame(test_result)
df_actual = pd.DataFrame(train_ds[0])

In [None]:
df_pred

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,16,17,18,19,20,21,22,23,24,25
0,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554,...,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554,1.170554
1,1.29301,1.293119,1.293099,1.293119,1.293119,1.293119,1.293119,1.293119,1.293119,1.293119,...,1.293119,1.293119,1.293119,1.293119,1.293119,1.293119,1.293119,1.293119,1.293119,1.293119
2,1.336501,1.336545,1.336489,1.336489,1.336545,1.33655,1.336524,1.336545,1.33655,1.33655,...,1.336545,1.336545,1.336545,1.336545,1.336545,1.336545,1.336545,1.336545,1.336545,1.336545
3,1.343069,1.343064,1.342984,1.343044,1.343052,1.342984,1.343002,1.343064,1.343072,1.343072,...,1.343087,1.343087,1.343087,1.343087,1.343087,1.343087,1.343087,1.343087,1.343087,1.343087
4,1.375174,1.375121,1.375313,1.375036,1.375317,1.375072,1.37522,1.375094,1.375217,1.375252,...,1.375198,1.375198,1.375198,1.375198,1.375198,1.375198,1.375198,1.375198,1.375198,1.375198
5,1.402991,1.429855,1.429755,1.429083,1.429697,1.429249,1.430241,1.429848,1.429969,1.430094,...,1.429824,1.429824,1.429824,1.429824,1.429824,1.429824,1.429824,1.429824,1.429824,1.429824
6,0.383199,1.477023,1.476209,1.465274,1.46835,1.466837,1.482053,1.47789,1.479513,1.479953,...,1.476856,1.476856,1.476856,1.476856,1.476856,1.476856,1.476856,1.476856,1.476856,1.476856
7,2.554446,1.451056,1.465172,1.350047,1.338477,1.325754,1.501209,1.470335,1.484572,1.486199,...,1.459361,1.459361,1.459361,1.459361,1.459361,1.459361,1.459361,1.459361,1.459361,1.459361
8,3.329627,1.222258,1.371737,0.955009,0.704728,0.637018,1.490662,1.362946,1.428571,1.433815,...,1.313279,1.313279,1.313279,1.313279,1.313279,1.313279,1.313279,1.313279,1.313279,1.313279
9,3.91757,0.776517,1.319047,0.65909,0.08437,-0.128863,1.511409,1.249411,1.401388,1.413216,...,1.131709,1.131709,1.131709,1.131709,1.131709,1.131709,1.131709,1.131709,1.131709,1.131709


In [None]:
df_actual

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,16,17,18,19,20,21,22,23,24,25
0,0.0,1.56006,-0.854437,0.720639,0.691729,0.944008,2.700632,1.312562,1.944263,1.897807,...,,,,,,,,,,
1,0.465116,1.689858,-0.514359,0.333295,0.942289,0.681604,2.785811,1.312562,1.944263,1.897807,...,,,,,,,,,,
2,0.930233,1.753589,-0.154209,-0.124995,0.992368,0.412951,2.845461,1.312562,1.944263,1.897807,...,,,,,,,,,,
3,1.395349,1.748068,0.212022,-0.556775,0.831727,0.14054,2.879232,1.312562,1.944263,1.897807,...,,,,,,,,,,
4,1.860465,1.673573,0.569904,-0.870579,0.494812,-0.133144,2.887018,1.312562,1.944263,1.897807,...,,,,,,,,,,
5,2.325581,1.533811,0.905607,-1.000151,0.053178,-0.405638,2.868949,1.312562,1.944263,1.897807,...,,,,,,,,,,
6,2.790698,1.335526,1.206832,-0.918192,-0.399688,-0.67453,2.825381,1.312562,1.944263,1.897807,...,,,,,,,,,,
7,3.255814,1.087824,1.463513,-0.64198,-0.767958,-0.937474,2.756889,1.312562,1.944263,1.897807,...,,,,,,,,,,
8,3.72093,0.801365,1.668203,-0.229808,-0.973515,-1.192211,2.664256,1.312562,1.944263,1.897807,...,,,,,,,,,,
9,4.186047,0.487549,1.816168,0.231092,-0.972449,-1.436585,2.548458,1.312562,1.944263,1.897807,...,,,,,,,,,,


In [None]:
model_name = f"big_train_{dt.now()}_"

current_dir = os.getcwd()

if not os.path.exists("./pre_trained_models/"):
    os.makedirs("./pre_trained_models/")

path = os.path.join(current_dir, "./pre_trained_models/")


ckpt_dir = f"./pre_trained_models/{model_name}"

# checkpoints.save_checkpoint(
#     ckpt_dir=path, target=state, step=batch_count, overwrite=True, prefix=model_name
# )

In [None]:
def jax_array_memory_usage(array):
    """Calculate the memory usage of a JAX array in bytes."""
    # Get the number of bytes per element based on the data type
    bytes_per_element = array.dtype.itemsize
    # Calculate the total number of elements in the array
    total_elements = np.prod(array.shape)
    # Calculate total memory usage
    memory_usage_bytes = bytes_per_element * total_elements
    return memory_usage_bytes


cat_memory_usage = jax_array_memory_usage(batch.categorical)
num_memory_usage = jax_array_memory_usage(batch.numeric)
memory_usage = cat_memory_usage + num_memory_usage
memory_usage_gb = memory_usage / 1024 / 1024 / 1024
print(f"Memory usage: {memory_usage} bytes")
print(f"Memory usage: {memory_usage_gb} gb")

AttributeError: 'ArrayImpl' object has no attribute 'categorical'

In [None]:
import jax
import jax.numpy as jnp


def create_future_mask(seq_len):
    # Create a basic mask for future sequences
    mask = jnp.triu(jnp.ones((seq_len, seq_len)), k=1).astype(bool)
    return mask


def create_custom_mask(batch_size, seq_len, n_columns):
    # Create the future sequence mask
    future_mask = create_future_mask(seq_len)

    # Expand the mask to match the required dimensions
    mask = jnp.expand_dims(future_mask, axis=0)  # Shape: (1, seq_len, seq_len)
    mask = jnp.expand_dims(mask, axis=-1)  # Shape: (1, seq_len, seq_len, 1)
    mask = jnp.tile(
        mask, (batch_size, 1, 1, n_columns)
    )  # Shape: (batch_size, seq_len, seq_len, n_columns)

    return mask


# Example usage:
batch_size = 2
seq_len = 5
n_columns = 3
embedding_dim = 4

mask = create_custom_mask(batch_size, seq_len, n_columns)

print(mask)

[[[[False False False]
   [ True  True  True]
   [ True  True  True]
   [ True  True  True]
   [ True  True  True]]

  [[False False False]
   [False False False]
   [ True  True  True]
   [ True  True  True]
   [ True  True  True]]

  [[False False False]
   [False False False]
   [False False False]
   [ True  True  True]
   [ True  True  True]]

  [[False False False]
   [False False False]
   [False False False]
   [False False False]
   [ True  True  True]]

  [[False False False]
   [False False False]
   [False False False]
   [False False False]
   [False False False]]]


 [[[False False False]
   [ True  True  True]
   [ True  True  True]
   [ True  True  True]
   [ True  True  True]]

  [[False False False]
   [False False False]
   [ True  True  True]
   [ True  True  True]
   [ True  True  True]]

  [[False False False]
   [False False False]
   [False False False]
   [ True  True  True]
   [ True  True  True]]

  [[False False False]
   [False False False]
   [False False 

In [None]:
import jax
import jax.numpy as jnp
from jax import random


# Create the custom mask function
def create_future_mask(seq_len):
    # Create a basic mask for future sequences
    mask = jnp.triu(jnp.ones((seq_len, seq_len)), k=1).astype(bool)
    return mask


def create_custom_mask(batch_size, seq_len, n_columns):
    # Create the future sequence mask
    future_mask = create_future_mask(seq_len)

    # Expand the mask to match the required dimensions
    mask = jnp.expand_dims(future_mask, axis=0)  # Shape: (1, seq_len, seq_len)
    mask = jnp.expand_dims(mask, axis=-1)  # Shape: (1, seq_len, seq_len, 1)
    mask = jnp.tile(
        mask, (batch_size, 1, 1, n_columns)
    )  # Shape: (batch_size, seq_len, seq_len, n_columns)

    return mask


# Generate random input data
key = random.PRNGKey(0)
batch_size = 2
seq_len = 5
n_columns = 3
embedding_dim = 4

input_data = random.normal(key, (batch_size, seq_len, n_columns, embedding_dim))
print("Input Data:")
print(input_data)

# Create the mask
mask = create_custom_mask(batch_size, seq_len, n_columns)
print("Mask:")
print(mask)

# Apply the mask to the input data (for demonstration purposes, we'll just print the mask applied to some dummy attention scores)
dummy_attention_scores = random.normal(key, (batch_size, seq_len, seq_len, n_columns))
masked_attention_scores = jnp.where(mask, -jnp.inf, dummy_attention_scores)

print("Dummy Attention Scores:")
print(dummy_attention_scores)

print("Masked Attention Scores:")
print(masked_attention_scores)

In [None]:
from flax import linen as nn

?nn.MultiHeadAttention

[0;31mInit signature:[0m
[0mnn[0m[0;34m.[0m[0mMultiHeadAttention[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mnum_heads[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdtype[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mtype[0m[0;34m[[0m[0mAny[0m[0;34m][0m[0;34m,[0m [0mnumpy[0m[0;34m.[0m[0mdtype[0m[0;34m,[0m [0mjax[0m[0;34m.[0m[0m_src[0m[0;34m.[0m[0mtyping[0m[0;34m.[0m[0mSupportsDType[0m[0;34m,[0m [0mAny[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mparam_dtype[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mtype[0m[0;34m[[0m[0mAny[0m[0;34m][0m[0;34m,[0m [0mnumpy[0m[0;34m.[0m[0mdtype[0m[0;34m,[0m [0mjax[0m[0;34m.[0m[0m_src[0m[0;34m.[0m[0mtyping[0m[0;34m.[0m[0mSupportsDType[0m[0;34m,[0m [0mAny[0m[0;34m][0m [0;34m=[0m [0;34m<[0m[0;32mclass[0m [0;34m'jax.numpy.float32'[

In [None]:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn


class MyModel(nn.Module):
    n_heads: int
    embedding_dim: int

    @nn.compact
    def __call__(self, x, mask):
        batch_size, seq_len, n_columns, embedding_dim = x.shape

        # Merge seq_len and n_columns dimensions
        x_reshaped = x.reshape(batch_size, seq_len * n_columns, embedding_dim)

        # Project the inputs to query, key, and value
        qkv_features = self.embedding_dim // self.n_heads
        q = nn.DenseGeneral(features=(self.n_heads, qkv_features))(x_reshaped)
        k = nn.DenseGeneral(features=(self.n_heads, qkv_features))(x_reshaped)
        v = nn.DenseGeneral(features=(self.n_heads, qkv_features))(x_reshaped)

        # Compute attention logits
        attn_logits = jnp.einsum("bthd,bThd->bhtT", q, k)

        # Apply mask: broadcast mask to match attention logits shape
        attn_logits = jnp.where(mask, -jnp.inf, attn_logits)

        # Compute attention weights
        attn_weights = nn.softmax(attn_logits, axis=-1)

        # Compute the attention output
        attn_output = jnp.einsum("bhtT,bThd->bthd", attn_weights, v)

        # Combine heads and reshape back to original dimensions
        attn_output = attn_output.reshape(
            batch_size, seq_len, n_columns, self.embedding_dim
        )

        return attn_output


def create_future_mask(seq_len):
    # Create a basic mask for future sequences
    mask = jnp.triu(jnp.ones((seq_len, seq_len)), k=1).astype(bool)
    return mask


def create_custom_mask(batch_size, seq_len, n_columns, num_heads):
    # Create the future sequence mask
    future_mask = create_future_mask(seq_len)

    # Expand the mask to match the required dimensions for attention
    mask = jnp.expand_dims(future_mask, axis=0)  # Shape: (1, seq_len, seq_len)
    mask = jnp.expand_dims(mask, axis=1)  # Shape: (1, 1, seq_len, seq_len)
    mask = jnp.tile(
        mask, (batch_size, num_heads, 1, 1)
    )  # Shape: (batch_size, num_heads, seq_len, seq_len)

    return mask


# Generate random input data
key = random.PRNGKey(0)
batch_size = 2
seq_len = 5
n_columns = 3
embedding_dim = 16
num_heads = 4

input_data = random.normal(key, (batch_size, seq_len, n_columns, embedding_dim))
print("Input Data:")
print(input_data)

# Create the mask
mask = create_custom_mask(batch_size, seq_len, n_columns, num_heads)
print("Mask:")
print(mask)

# Initialize and apply the model
model = MyModel(n_heads=num_heads, embedding_dim=embedding_dim)
variables = model.init(key, input_data, mask)
output = model.apply(variables, input_data, mask)

print("Output:")
print(output)

Input Data:
[[[[-5.47799706e-01 -1.17179680e+00  1.45061789e-02  2.34819144e-01
     3.00504017e+00  1.99740767e-01 -6.31268919e-01 -2.68845528e-01
    -5.21487474e-01 -1.72483146e+00 -2.67246771e+00 -1.67209733e+00
    -1.23799145e-01 -7.75377810e-01  7.31753230e-01 -4.72956657e-01]
   [-7.59660363e-01 -1.43394160e+00  9.79405999e-01  2.70364374e-01
    -7.26617202e-02  1.33200324e+00  1.11467469e+00  5.46375573e-01
    -3.29602033e-01 -6.35213614e-01  8.06641698e-01  1.48840249e+00
     6.16176844e-01 -4.41124678e-01  5.02394319e-01 -4.20634806e-01]
   [-2.35249791e-02 -1.27048820e-01 -7.40331471e-01 -1.69789433e+00
    -5.50638080e-01  3.32435369e-01 -6.02500677e-01 -7.93417037e-01
    -1.91679239e+00  4.30762082e-01  2.93178469e-01  3.57544348e-02
     2.85551995e-01  1.14142168e+00 -6.61235869e-01  1.30349076e+00]]

  [[ 7.65341893e-02 -5.83034575e-01  3.56375240e-02 -1.11470902e+00
     6.09569371e-01 -1.52801716e+00  1.40855372e+00 -1.50325167e+00
    -1.90268196e-02  1.37721896

ValueError: Incompatible shapes for broadcasting: shapes=[(2, 4, 5, 5), (), (2, 4, 15, 15)]

In [None]:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn


class MyModel(nn.Module):
    n_heads: int

    @nn.compact
    def __call__(self, x, mask):
        # Apply multi-head attention with the custom mask
        attn = nn.MultiHeadAttention(
            num_heads=self.n_heads, qkv_features=16, use_bias=False
        )(x, mask=mask)
        return attn


def create_future_mask(seq_len):
    # Create a basic mask for future sequences
    mask = jnp.triu(jnp.ones((seq_len, seq_len)), k=1).astype(bool)
    return mask


def create_custom_mask(batch_size, seq_len, num_heads):
    # Create the future sequence mask
    future_mask = create_future_mask(seq_len)

    # Expand the mask to match the required dimensions for attention
    mask = jnp.expand_dims(future_mask, axis=0)  # Shape: (1, seq_len, seq_len)
    mask = jnp.expand_dims(mask, axis=1)  # Shape: (1, 1, seq_len, seq_len)
    mask = jnp.tile(
        mask, (batch_size, num_heads, 1, 1)
    )  # Shape: (batch_size, num_heads, seq_len, seq_len)

    return mask


# Generate random input data
key = random.PRNGKey(0)
batch_size = 2
seq_len = 5
n_columns = 3
embedding_dim = 16
num_heads = 4

input_data = random.normal(key, (batch_size, seq_len, n_columns, embedding_dim))
# print("Input Data:")
# print(input_data)

# Create the mask
mask = create_custom_mask(batch_size, seq_len, num_heads)
# print("Mask:")
# print(mask)

# Initialize and apply the model
model = MyModel(n_heads=num_heads)
variables = model.init(key, input_data, mask)
output = model.apply(variables, input_data, mask)

print("Output:")
print(output)

ValueError: Incompatible shapes for broadcasting: shapes=[(2, 4, 5, 5), (2, 5, 4, 3, 3), ()]

In [None]:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn


class MyModel(nn.Module):
    n_heads: int

    @nn.compact
    def __call__(self, x, mask):
        # Reshape input to combine seq_len and n_columns
        batch_size, seq_len, n_columns, embedding_dim = x.shape
        x = x.reshape(batch_size, seq_len * n_columns, embedding_dim)

        mask = nn.make_causal_mask(x)
        print(mask.shape)  #                                        (2, 15, 1, 16, 16)
        # ValueError: Incompatible shapes for broadcasting: shapes=[(2, 15, 1, 16, 16),
        #                                                           (2, 4, 15, 15), ()]
        mask = mask.reshape(batch_size, seq_len, n_columns, seq_len, n_columns)
        # Apply multi-head attention with the custom mask
        attn = nn.MultiHeadAttention(
            num_heads=self.n_heads, qkv_features=embedding_dim, use_bias=False
        )(x, mask=mask)
        print(attn.shape)
        # Reshape the output back to the original shape
        attn = attn.reshape(batch_size, seq_len, n_columns, embedding_dim)
        return attn


# Generate random input data
key = random.PRNGKey(0)
batch_size = 2
seq_len = 5
n_columns = 3
embedding_dim = 16
num_heads = 4

input_data = random.normal(key, (batch_size, seq_len, n_columns, embedding_dim))

# Create the mask
mask = create_custom_mask(batch_size, seq_len, num_heads, n_columns)

# Initialize and apply the model
model = MyModel(n_heads=num_heads)
variables = model.init(key, input_data, mask)
output = model.apply(variables, input_data, mask)

(2, 15, 1, 16, 16)


TypeError: cannot reshape array of shape (2, 15, 1, 16, 16) (size 7680) into shape (2, 5, 3, 5, 3) (size 450)

In [None]:
nn.make_causal_mask(jnp.ones((2, 5, 3, 16))).shape

(2, 5, 3, 1, 16, 16)

In [None]:
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning
from typing import Any


class DecoderBlock(nn.Module):
    num_heads: int
    qkv_dim: int
    mlp_dim: int
    dropout_rate: float

    def setup(self):
        self.self_attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.qkv_dim,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            dropout_rate=self.dropout_rate,
        )
        self.mlp = nn.Sequential(
            [nn.Dense(self.mlp_dim), nn.relu, nn.Dense(self.qkv_dim)]
        )
        self.layer_norm1 = nn.LayerNorm()
        self.layer_norm2 = nn.LayerNorm()
        self.dropout = nn.Dropout(rate=self.dropout_rate)

    def __call__(self, x, causal_mask, deterministic):
        # Self-attention block
        x = self.layer_norm1(x)
        attn_out = self.self_attention(
            query=x, key=x, value=x, mask=causal_mask, deterministic=deterministic
        )
        x = x + self.dropout(attn_out, deterministic=deterministic)

        # Feed-forward block
        x = self.layer_norm2(x)
        mlp_out = self.mlp(x)
        x = x + self.dropout(mlp_out, deterministic=deterministic)

        return x


class Decoder(nn.Module):
    vocab_size: int
    num_layers: int
    num_heads: int
    qkv_dim: int
    mlp_dim: int
    max_len: int
    dropout_rate: float

    def setup(self):
        self.token_embedding = nn.Embed(self.vocab_size, self.qkv_dim)
        self.position_embedding = nn.Embed(self.max_len, self.qkv_dim)
        self.decoder_blocks = [
            DecoderBlock(
                num_heads=self.num_heads,
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                dropout_rate=self.dropout_rate,
            )
            for _ in range(self.num_layers)
        ]

    def __call__(self, x, deterministic=True):
        # Embedding and positional encoding
        seq_len = x.shape[1]
        x = self.token_embedding(x) + self.position_embedding(jnp.arange(seq_len))

        # Create the causal mask
        causal_mask = nn.make_causal_mask(x)

        # Apply decoder blocks
        for block in self.decoder_blocks:
            x = block(x, causal_mask, deterministic)

        return x


# Example usage
vocab_size = 32000
num_layers = 6
num_heads = 8
qkv_dim = 512
mlp_dim = 2048
max_len = 512
dropout_rate = 0.1

decoder = Decoder(
    vocab_size=vocab_size,
    num_layers=num_layers,
    num_heads=num_heads,
    qkv_dim=qkv_dim,
    mlp_dim=mlp_dim,
    max_len=max_len,
    dropout_rate=dropout_rate,
)

# Create a random input sequence of token IDs
import jax.random as random

key = random.PRNGKey(0)
x = random.randint(key, (1, max_len), 0, vocab_size)

# Forward pass
deterministic = True
output = decoder(x, deterministic=deterministic)

AttributeError: "Decoder" object has no attribute "token_embedding". If "token_embedding" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
from typing import Any


class DecoderBlock(nn.Module):
    num_heads: int
    embed_dim: int
    mlp_dim: int

    @nn.compact
    def __call__(self, x, mask=None):
        # Self-attention block
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.embed_dim,
            kernel_init=nn.initializers.xavier_uniform(),
        )(x, mask=mask)
        x = nn.Dropout(0.1)(x, deterministic=True)
        residual = x

        # Feed-forward block
        x = nn.LayerNorm()(x)
        x = nn.Dense(self.mlp_dim, kernel_init=nn.initializers.xavier_uniform())(x)
        x = nn.relu(x)
        x = nn.Dense(self.embed_dim, kernel_init=nn.initializers.xavier_uniform())(x)
        x = nn.Dropout(0.1)(x, deterministic=True)

        return residual + x


class SimpleDecoder(nn.Module):
    num_layers: int
    num_heads: int
    embed_dim: int
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        # Create a causal mask
        mask = nn.make_causal_mask(x)

        for _ in range(self.num_layers):
            x = DecoderBlock(self.num_heads, self.embed_dim, self.mlp_dim)(x, mask)

        return x


# Example usage
key = jax.random.PRNGKey(0)
x = jax.random.normal(
    key, (1, 10, 64)
)  # Example input: batch size 1, sequence length 10, embedding size 64

model = SimpleDecoder(num_layers=2, num_heads=8, embed_dim=64, mlp_dim=256)
params = model.init(key, x)
y = model.apply(params, x)

print(y)

ValueError: Incompatible shapes for broadcasting: shapes=[(1, 10, 1, 64, 64), (1, 8, 10, 10), ()]

In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
from typing import Any


class DecoderBlock(nn.Module):
    num_heads: int
    embed_dim: int
    mlp_dim: int

    @nn.compact
    def __call__(self, x, mask=None):
        # Self-attention block
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.embed_dim,
            kernel_init=nn.initializers.xavier_uniform(),
        )(x, mask=mask)
        x = nn.Dropout(0.1)(x, deterministic=True)
        residual = x

        # Feed-forward block
        x = nn.LayerNorm()(x)
        x = nn.Dense(self.mlp_dim, kernel_init=nn.initializers.xavier_uniform())(x)
        x = nn.relu(x)
        x = nn.Dense(self.embed_dim, kernel_init=nn.initializers.xavier_uniform())(x)
        x = nn.Dropout(0.1)(x, deterministic=True)

        return residual + x


class SimpleDecoder(nn.Module):
    num_layers: int
    num_heads: int
    embed_dim: int
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        # Create a causal mask
        mask = nn.make_causal_mask(x)

        for _ in range(self.num_layers):
            x = DecoderBlock(self.num_heads, self.embed_dim, self.mlp_dim)(x, mask)

        return x


# Example usage
key = jax.random.PRNGKey(0)
x = jax.random.normal(
    key, (1, 10, 64)
)  # Example input: batch size 1, sequence length 10, embedding size 64

model = SimpleDecoder(num_layers=2, num_heads=8, embed_dim=64, mlp_dim=256)
params = model.init(key, x)
y = model.apply(params, x)

print(y)

ValueError: Incompatible shapes for broadcasting: shapes=[(1, 10, 1, 64, 64), (1, 8, 10, 10), ()]

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (
    Layer,
    LayerNormalization,
    Dense,
    Dropout,
    MultiHeadAttention,
)


class DecoderBlock(Layer):
    def __init__(self, num_heads, embed_dim, mlp_dim, dropout_rate=0.1):
        super(DecoderBlock, self).__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout_rate

        self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = tf.keras.Sequential(
            [Dense(mlp_dim, activation="relu"), Dense(embed_dim)]
        )
        self.layernorm1 = LayerNormalization()
        self.layernorm2 = LayerNormalization()
        self.dropout1 = Dropout(dropout_rate)
        self.dropout2 = Dropout(dropout_rate)

    def call(self, x, training, mask=None):
        attn_output = self.attention(x, x, attention_mask=mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        ffn_output = self.dense_proj(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)


class SimpleDecoder(Layer):
    def __init__(self, num_layers, num_heads, embed_dim, mlp_dim, dropout_rate=0.1):
        super(SimpleDecoder, self).__init__()
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.mlp_dim = mlp_dim

        self.dec_layers = [
            DecoderBlock(num_heads, embed_dim, mlp_dim, dropout_rate)
            for _ in range(num_layers)
        ]

    def call(self, x, training, mask=None):
        for i in range(self.num_layers):
            x = self.dec_layers[i](x, training, mask)
        return x


def create_causal_mask(seq_len):
    mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
    return mask


# Example usage
batch_size = 1
seq_len = 10
embed_dim = 64

# Input tensor
x = tf.random.normal((batch_size, seq_len, embed_dim))

# Create the model
num_layers = 2
num_heads = 8
mlp_dim = 256
dropout_rate = 0.1

decoder = SimpleDecoder(num_layers, num_heads, embed_dim, mlp_dim, dropout_rate)

# Create a causal mask
causal_mask = create_causal_mask(seq_len)
causal_mask = causal_mask[tf.newaxis, tf.newaxis, :, :]  # (1, 1, seq_len, seq_len)

# Apply the decoder
output = decoder(x, training=True, mask=causal_mask)

print(output)

ValueError: Exception encountered when calling SimpleDecoder.call().

[1mOnly input tensors may be passed as positional arguments. The following argument value should be passed as a keyword argument: True (of type <class 'bool'>)[0m

Arguments received by SimpleDecoder.call():
  • x=tf.Tensor(shape=(1, 10, 64), dtype=float32)
  • training=True
  • mask=tf.Tensor(shape=(1, 1, 10, 10), dtype=float32)