---
title: "Implementing a Perceiver in JAX for fun ✨"
author:
  - name: "Tugdual Kerjan"
    url: https://tugdual.fr
    email: tkerjan@outlook.com
date: "November 15, 2024"
number-sections: true
reference-location: margin
toc: true
format: 
  html:
    standalone: true
    embed-resources: true
    self-contained-math: true
    code-fold: false
    code-tools: true
execute:
  output:
    false
bibliography: assets/bib.bibtex
theme: united
github: "https://github.com/TugdualKerjan/ResNet-for-JAX"
lightbox: true
---

For the full project, visit the [GitHub repository](https://github.com/TugdualKerjan/Adventures-in-Equinox).

# Context 👀

I'm trying to rewrite XTTS in JAX to understand how it works. 

We are going to implement the Perceiver Resampler, a module introduced by [@alayrac2022flamingovisuallanguagemodel] to address the inefficiency of attention when processing high-dimensional or multimodal inputs. Unlike in a traditional transformer, the input features are projected into a fixed-size latent array by a resampling module, which extracts and compresses information from inputs of varying size. We thus avoid having to constrain ourselves to a fixed input length. This work builds on the Perceiver [@jaegle2021perceivergeneralperceptioniterative], a paper proposing that a fixed-size latent array attend to the input, which avoids the quadratic compute cost of self-attention over the full input.

::: {.column-margin}

![A high level overview of Flamingo, a model based on this concept from [@alayrac2022flamingovisuallanguagemodel]](assets/arch.png)

:::

__Attention__

The most important part of this model is the concept of attention: it takes the input and learns which parts are relevant to one another. For example, seeing two eyes and a mouth in the input lets the model infer that there is a face.
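
To make this concrete, here is a minimal sketch of scaled dot-product attention in JAX. It has no learned projections and no heads, and the function name `dot_product_attention` and the toy shapes are my own choices for illustration, not something taken from XTTS.

```python
import jax
import jax.numpy as jnp


def dot_product_attention(q, k, v):
    # q: (n_q, d), k: (n_k, d), v: (n_k, d_v)
    scores = q @ k.T / jnp.sqrt(q.shape[-1])  # relevance of each key to each query
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v, weights


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (4, 8))
k = jax.random.normal(key_k, (6, 8))
v = jax.random.normal(key_v, (6, 8))

out, weights = dot_product_attention(q, k, v)
print(out.shape, weights.shape)  # (4, 8) (4, 6)
```

Each output row is a weighted average of the value rows, with the weights telling us which inputs that query found relevant.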

__Resampling__

My favorite part: to be agnostic to input length, the input is flattened and the equivalent of positional embeddings is added.

![Resampling of the input [@alayrac2022flamingovisuallanguagemodel]](assets/resampler.png)
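
The key property is that the output size depends only on the number of latents, not on the input length. A rough sketch of that idea, assuming a single head and no learned projections (the `resample` helper and the sizes are made up for this example):

```python
import jax
import jax.numpy as jnp


def resample(latents, x):
    # latents: (num_latents, d) act as queries; x: (n, d) is an input of any length n.
    # Keys and values are the input concatenated with the latents themselves.
    context = jnp.concatenate([x, latents], axis=0)
    scores = latents @ context.T / jnp.sqrt(latents.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ context  # (num_latents, d)


latents = jax.random.normal(jax.random.PRNGKey(0), (32, 16))
short = jax.random.normal(jax.random.PRNGKey(1), (50, 16))
long = jax.random.normal(jax.random.PRNGKey(2), (500, 16))

print(resample(latents, short).shape, resample(latents, long).shape)  # (32, 16) (32, 16)
```

The attention cost grows linearly with the input length here, since only the 32 latents issue queries.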

# Goal 🎯

Train a Perceiver Resampler on ??? !



We start by importing our favorite libraries:

In [None]:
%load_ext autoreload
%autoreload 2

In [13]:
import jax
import jax.numpy as jnp
import equinox as eqx
import equinox.nn as nn
from typing import Optional
import torch
import numpy as np


# config.update("jax_enable_x64", True)

## GEGLU activation

In [14]:
class GEGLU(eqx.Module):
    def __call__(self, x):
        x, gate = jnp.split(x, 2, axis=-1)
        return jax.nn.gelu(gate, approximate=False) * x

We can check that this matches the PyTorch implementation with a quick test:

In [3]:
from TTS.tts.layers.xtts.perceiver_encoder import GEGLU as theirGEGLU

our_x = jax.random.normal(jax.random.PRNGKey(1), shape=(30,))
their_x = torch.from_numpy(np.array(our_x))

ours = GEGLU()
theirs = theirGEGLU()

# Reuse the instances created above instead of building new ones.
our_y = ours(our_x)
their_y = theirs(their_x)

torch.testing.assert_close(their_y, torch.from_numpy(np.array(our_y)))
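
For intuition, GEGLU splits the last axis in two halves, uses one half as the value and passes the other through GELU as a gate, so the output has half as many features as the input. A small standalone sketch (the function form `geglu` is just for illustration):

```python
import jax
import jax.numpy as jnp


def geglu(x):
    # First half of the last axis is the value, second half is the gate.
    value, gate = jnp.split(x, 2, axis=-1)
    return value * jax.nn.gelu(gate, approximate=False)


x = jnp.arange(8.0).reshape(2, 4)
y = geglu(x)
print(x.shape, y.shape)  # (2, 4) (2, 2)
```

This halving is why the first linear layer of the feed-forward block later produces `dim_inner * 2` features.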

# CausalConv1d

Since our model has to predict the next token, convolution layers that mix in surrounding tokens would let the model cheat by incorporating future-token information. To prevent this, we create a custom layer that pads only on the left, so each output position sees only the current and past tokens.

In [15]:
class CausalConv1d(eqx.nn.Conv1d):
    causal_padding: int

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        (kernel_size,) = self.kernel_size
        (dilation,) = self.dilation
        (stride,) = self.stride

        assert stride == 1
        self.causal_padding = dilation * (kernel_size - 1)

    def __call__(self, x: jax.Array, *, key: Optional[jax.Array] = None) -> jax.Array:
        causal_padded_x = jax.numpy.pad(
            x, ((0, 0), (self.causal_padding, 0)), mode="constant", constant_values=0.0
        )
        # print(causal_padded_x.shape)
        return super().__call__(causal_padded_x, key=key)
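
To convince ourselves that left padding really hides the future, here is a single-channel sketch with dilation 1 written with plain `jax.numpy`. The helper `causal_conv1d` is hypothetical and much simpler than the Equinox layer above; it only illustrates the padding trick.

```python
import jax.numpy as jnp


def causal_conv1d(x, w):
    # Pad kernel_size - 1 zeros on the left only, then cross-correlate.
    pad = w.shape[0] - 1
    x_padded = jnp.pad(x, (pad, 0))
    # jnp.convolve flips its kernel, so flip it back to get cross-correlation.
    return jnp.convolve(x_padded, w[::-1], mode="valid")


w = jnp.array([0.5, -1.0, 2.0])
x = jnp.arange(10.0)
x_future_changed = x.at[5:].set(100.0)  # tamper with positions 5 and later

y = causal_conv1d(x, w)
y_changed = causal_conv1d(x_future_changed, w)

# Outputs before position 5 are unaffected by the change.
print(y.shape, bool(jnp.allclose(y[:5], y_changed[:5])))  # (10,) True
```

The output keeps the input length, and position `t` only depends on inputs `t - 2` to `t`.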

In [6]:
from TTS.tts.layers.xtts.perceiver_encoder import CausalConv1d as theirCausalConv1d
import torch
import numpy as np

our_x = jax.random.normal(jax.random.PRNGKey(1), shape=(2, 10))
their_x = torch.from_numpy(np.array(our_x))

ours = CausalConv1d(2, 2, 3, key=jax.random.PRNGKey(2))
theirs = theirCausalConv1d(2, 2, 3)

their_params = {key: value for key, value in theirs.named_parameters()}


def update_weights(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    if path in their_params:
        if "bias" in path:
            return jnp.expand_dims(their_params[path].detach().numpy(), -1)
        return jnp.array(their_params[path].detach().numpy())
    print(path)
    return x


ours = jax.tree_util.tree_map_with_path(update_weights, ours)

our_y = ours(our_x)
their_y = theirs(their_x)

torch.testing.assert_close(their_y, torch.from_numpy(np.array(our_y)))

causal_padding


# Attention Mechanism

In [17]:
class Attend(eqx.Module):
    dropout: float
    causal: bool
    attn_dropout: nn.Dropout

    def __init__(self, dropout=0.0, causal=False, use_flash=False):
        self.dropout = dropout
        self.attn_dropout = eqx.nn.Dropout(dropout, inference=True)

        self.causal = causal

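    def get_mask(self, n):
        # True strictly above the diagonal, i.e. at future positions.
        return jnp.triu(jnp.ones((n, n), dtype=bool), k=1)

    def __call__(self, q, k, v, mask=None):
        n = q.shape[-2]
        scale = q.shape[-1] ** -0.5
        kq = jnp.matmul(q, jnp.transpose(k, (0, 2, 1))) * scale
        neg_max = -jnp.finfo(kq.dtype).max
        # Key mask: True marks keys to keep. Masked-out keys get a large
        # negative score so the softmax gives them (almost) zero weight.
        if mask is not None:
            mask = jnp.expand_dims(mask, 0)
            kq = jnp.where(mask, kq, neg_max)

        if self.causal:
            # Hide future positions (where the mask is True) before the softmax.
            kq = jnp.where(self.get_mask(n), neg_max, kq)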

        attn = jax.nn.softmax(kq, axis=-1)
        attn = self.attn_dropout(attn)

        out = jnp.matmul(attn, v)

        return out
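
A quick way to sanity-check causal masking is to apply it to a matrix of equal scores and look at the resulting attention weights. The sketch below is standalone and uses a toy size of `n = 4`; the convention assumed is that `True` in the mask marks future positions to hide.

```python
import jax
import jax.numpy as jnp

n = 4
scores = jnp.zeros((n, n))  # every key equally relevant before masking

# True strictly above the diagonal: key j lies in the future of query i when j > i.
future = jnp.triu(jnp.ones((n, n), dtype=bool), k=1)
masked = jnp.where(future, -jnp.finfo(scores.dtype).max, scores)

attn = jax.nn.softmax(masked, axis=-1)
print(attn)
```

Row `i` spreads its weight evenly over positions `0` to `i` and puts zero weight on everything after, so the matrix is lower triangular.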

In [10]:
from TTS.tts.layers.xtts.perceiver_encoder import Attend as theirAttend
import torch
import numpy as np

ours = Attend(0.1, causal=False, use_flash=False)
theirs = theirAttend(0.1, causal=False, use_flash=False)

x = jax.random.normal(jax.random.PRNGKey(1), shape=(6, 24, 10))
our_x, our_k, our_v = jax.numpy.split(x, 3, axis=0)
their_q = torch.from_numpy(np.array(our_x))
their_k = torch.from_numpy(np.array(our_k))
their_v = torch.from_numpy(np.array(our_v))


their_params = {key: value for key, value in theirs.named_parameters()}


def update_weights(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    if path in their_params:
        if "bias" in path:
            return jnp.expand_dims(their_params[path].detach().numpy(), -1)
        return jnp.array(their_params[path].detach().numpy())
    print(path)
    return x


ours = jax.tree_util.tree_map_with_path(update_weights, ours)
our_y = ours(our_x, our_k, our_v)

their_q = their_q.unsqueeze(0)
their_k = their_k.unsqueeze(0)
their_v = their_v.unsqueeze(0)

their_y = theirs(their_q, their_k, their_v)

their_y = their_y.squeeze(0)

torch.testing.assert_close(their_y, torch.from_numpy(np.array(our_y)))

dropout
causal
attn_dropout.p
attn_dropout.inference
(2, 24, 10)


In [18]:
from functools import partial
from einops import rearrange


class Attention(eqx.Module):

    cross_attn_include_queries: bool
    scale: float
    heads: int

    attend: Attend
    to_q: nn.Linear
    to_kv: nn.Linear
    to_out: nn.Linear

    dim_inner: int

    def __init__(
        self,
        dim,
        *,
        dim_context=None,
        causal=False,
        dim_head=64,
        heads=8,
        dropout=0.0,
        use_flash=False,
        cross_attn_include_queries=False,
        key=None,
    ):
        key1, key2, key3 = jax.random.split(key, 3)

        self.scale = dim_head**-0.5
        self.heads = heads
        self.cross_attn_include_queries = cross_attn_include_queries
        self.dim_inner = dim_head * heads

        self.attend = Attend(dropout, causal)

        self.to_q = nn.Linear(dim, self.dim_inner, use_bias=False, key=key1)
        self.to_kv = nn.Linear(dim, self.dim_inner * 2, use_bias=False, key=key2)
        self.to_out = nn.Linear(self.dim_inner, dim, use_bias=False, key=key3)

    # @partial(jax.jit, static_argnums=2)
    def __call__(self, x, context, mask=None):
        # Optionally let the queries attend to themselves as well by
        # prepending them to the keys/values.
        if self.cross_attn_include_queries:
            context = jnp.concatenate([x, context], axis=-2)
            if mask is not None:
                # The prepended query positions are always valid keys.
                mask = jnp.pad(mask, ((0, 0), (x.shape[-2], 0)), constant_values=True)
        q, k, v = (
            jax.vmap(jax.vmap(self.to_q))(x),
            *jnp.split(jax.vmap(jax.vmap(self.to_kv))(context), 2, axis=-1),
        )

        # Split the feature axis into heads: (b, n, h * d) -> (b, h, n, d).
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
        )

        out = jax.vmap(self.attend)(q, k, v, mask)
        out = rearrange(out, "b h n d -> b n (h d)")

        return jax.vmap(jax.vmap(self.to_out))(out)
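
The head split is easy to get subtly wrong. As a sketch with toy sizes, the `rearrange` pattern `"b n (h d) -> b h n d"` corresponds to a reshape followed by a transpose, whereas a direct reshape to `(b, h, n, d)` gives the right shape but mixes tokens across heads:

```python
import jax.numpy as jnp

b, n, h, d = 2, 5, 2, 3
t = jnp.arange(b * n * h * d).reshape(b, n, h * d)

# Same result as rearrange(t, "b n (h d) -> b h n d", h=h)
split = t.reshape(b, n, h, d).transpose(0, 2, 1, 3)

# Same shape, but the elements end up in the wrong places
naive = t.reshape(b, h, n, d)

# Same result as rearrange(split, "b h n d -> b n (h d)")
merged = split.transpose(0, 2, 1, 3).reshape(b, n, h * d)

print(split.shape, bool((merged == t).all()), bool((split == naive).all()))
# (2, 2, 5, 3) True False
```

Splitting and merging round-trips back to the original array, which is a handy property to assert when porting this kind of code.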

In [91]:
from TTS.tts.layers.xtts.perceiver_encoder import Attention as theirAttention
import torch
import equinox.nn as nn
import numpy as np

ours = Attention(
    6,
    causal=False,
    dim_head=4,
    heads=2,
    dropout=0,
    cross_attn_include_queries=True,
    key=jax.random.PRNGKey(1),
)
theirs = theirAttention(
    6, causal=False, dim_head=4, heads=2, dropout=0, cross_attn_include_queries=True
)

x = jax.random.normal(jax.random.PRNGKey(1), shape=(2, 10, 6))

our_latents, our_x = jax.numpy.split(x, 2, axis=1)

their_latents = torch.from_numpy(np.array(our_latents))
their_x = torch.from_numpy(np.array(our_x))

their_params = {key: value for key, value in theirs.named_parameters()}


def update_weights(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    if path in their_params:
        if "bias" in path:
            return jnp.expand_dims(their_params[path].detach().numpy(), -1)
        return jnp.array(their_params[path].detach().numpy())
    print(path)
    return x


ours = jax.tree_util.tree_map_with_path(update_weights, ours)

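# Assignments made inside %timeit do not persist, so compute the outputs first.
our_y = ours(our_latents, our_x, mask=None)
their_y = theirs(their_latents, their_x)

%timeit ours(our_latents, our_x, mask=None)

%timeit theirs(their_latents, their_x)

torch.testing.assert_close(their_y, torch.from_numpy(np.array(our_y)))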

cross_attn_include_queries
scale
heads
attend.dropout
attend.causal
attend.attn_dropout.p
attend.attn_dropout.inference
dim_inner
2.78 ms ± 62.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
214 μs ± 16.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [19]:
class FeedForward(eqx.Module):
    causal_conv: bool
    ff1: nn.Linear
    ff2: nn.Linear
    act: GEGLU
    conv: CausalConv1d

    def __init__(self, dim, mult=4, causal_conv=False, key=None):
        key1, key2, key3 = jax.random.split(key, 3)

        self.causal_conv = causal_conv
        dim_inner = int(dim * mult * 2 / 3)
        self.conv = CausalConv1d(dim_inner, dim_inner, 3, key=key3)
        self.act = GEGLU()
        self.ff1 = nn.Linear(dim, dim_inner * 2, key=key1)
        self.ff2 = nn.Linear(dim_inner, dim, key=key2)

    def __call__(self, x):
        y = jax.vmap(self.ff1)(x)
        y = self.act(y)
        if self.causal_conv:
            y = jnp.permute_dims(y, (1, 0))
            y = self.conv(y)
            y = jnp.permute_dims(y, (1, 0))
        y = jax.vmap(self.ff2)(y)

        return y
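
The hidden width follows from `int(dim * mult * 2 / 3)`, and the first linear layer outputs twice that because GEGLU halves it again. A quick arithmetic check for the two sizes used in this post (the helper name `ff_inner_dim` is mine):

```python
def ff_inner_dim(dim: int, mult: int = 4) -> int:
    # Same formula as in FeedForward.__init__
    return int(dim * mult * 2 / 3)


for dim in (30, 1024):
    inner = ff_inner_dim(dim)
    # columns: dim, ff1 output features, ff2 input features
    print(dim, inner * 2, inner)
# 30 160 80
# 1024 5460 2730
```

For `dim=30` this gives weights of shape `(160, 30)` for `ff1` and `(30, 80)` for `ff2`.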

In [157]:
from TTS.tts.layers.xtts.perceiver_encoder import FeedForward as theirFF
import torch
import equinox.nn as nn
import numpy as np

ours = FeedForward(30, 4, False, key=jax.random.PRNGKey(1))
theirs = theirFF(30, 4, False)

our_x = jax.random.normal(jax.random.PRNGKey(1), shape=(2, 160, 30))
their_x = torch.from_numpy(np.array(our_x))

their_params = {key: value for key, value in theirs.named_parameters()}

mapping = {
    "ff1.weight": "0.weight",
    "ff1.bias": "0.bias",
    "ff2.weight": "2.weight",
    "ff2.bias": "2.bias",
}


def update_weights(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    if path in their_params:
        if "bias" in path and "conv" not in path:
            return jnp.expand_dims(their_params[path].detach().numpy(), -1)
        return jnp.array(their_params[path].detach().numpy())
    if path in mapping.keys():
        if "bias" in path:
            return jnp.array(their_params[mapping[path]].detach().numpy())
        return jnp.array(their_params[mapping[path]].detach().numpy())
    return x


def print_shapes(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    print(path)

    if "weight" in path or "bias" in path:
        print(f"Param {path} has shape {x.shape}")


jax.tree_util.tree_map_with_path(print_shapes, ours)
ours = jax.tree_util.tree_map_with_path(update_weights, ours)
jax.tree_util.tree_map_with_path(print_shapes, ours)


our_y = jax.vmap(ours)(our_x)
their_y = theirs(their_x)

# their_y = their_y.squeeze(0)

torch.testing.assert_close(their_y, torch.from_numpy(np.array(our_y)))

causal_conv
ff1.weight
Param ff1.weight has shape (160, 30)
ff1.bias
Param ff1.bias has shape (160,)
ff2.weight
Param ff2.weight has shape (30, 80)
ff2.bias
Param ff2.bias has shape (30,)
conv.weight
Param conv.weight has shape (80, 80, 3)
conv.bias
Param conv.bias has shape (80, 1)
conv.causal_padding
causal_conv
ff1.weight
Param ff1.weight has shape (160, 30)
ff1.bias
Param ff1.bias has shape (160,)
ff2.weight
Param ff2.weight has shape (30, 80)
ff2.bias
Param ff2.bias has shape (30,)
conv.weight
Param conv.weight has shape (80, 80, 3)
conv.bias
Param conv.bias has shape (80, 1)
conv.causal_padding
Incoming (160, 30)
Incoming (160, 160)


In [20]:
from einops import repeat


class PerceiverResampler(eqx.Module):

    proj_context: eqx.Module  # nn.Linear or nn.Identity
    latents: jax.Array
    layers: list
    norm: RMSNorm

    def __init__(
        self,
        *,
        dim,
        depth=2,
        dim_context=None,
        num_latents=32,
        dim_head=64,
        heads=8,
        ff_mult=4,
        use_flash_attn=False,
        key=None,
    ):

        key1, key2, key3 = jax.random.split(key, 3)
        if dim_context is None:
            dim_context = dim

        self.proj_context = (
            nn.Linear(dim_context, dim, key=key1)
            if dim != dim_context
            else nn.Identity()
        )

        self.latents = jax.random.normal(key3, (num_latents, dim))

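        # Give the attention and feed-forward blocks of each layer their own key.
        self.layers = [
            (
                Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    use_flash=use_flash_attn,
                    cross_attn_include_queries=True,
                    key=attn_key,
                ),
                FeedForward(dim=dim, mult=ff_mult, key=ff_key),
            )
            for attn_key, ff_key in map(
                jax.random.split, jax.random.split(key2, depth)
            )
        ]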

        self.norm = RMSNorm(dim)

    def __call__(self, x, mask=None):
        # x: (batch, seq, dim_context). eqx.nn.Linear acts on a single vector,
        # so map over both the batch and the sequence axes.
        y = jax.vmap(jax.vmap(self.proj_context))(x)
        latents = repeat(self.latents, "n d -> b n d", b=x.shape[0])

        for attn, ff in self.layers:
            print(latents[0, 0])

            latents = attn(latents, y, mask) + latents
            print(latents[0, 0])
            latents = jax.vmap(ff)(latents) + latents
        return jax.vmap(jax.vmap(self.norm))(latents)

In [22]:
import torch
import numpy as np
import jax
import equinox as eqx

# Exported from .ipynb using python3 export.py (Exports all cells with the export tag)
from TTS.tts.layers.xtts.trainer.gpt_trainer import (
    GPTArgs,
    GPTTrainerConfig,
    GPTTrainer,
    XttsAudioConfig,
)

# init args and config
model_args = GPTArgs(
    max_conditioning_length=132300,  # 6 secs
    min_conditioning_length=11025,  # 0.5 secs
    debug_loading_failures=False,
    max_wav_length=220000,  # ~10 seconds at 22050 Hz
    max_text_length=200,
    mel_norm_file="checkpoints/mel_stats.pth",
    dvae_checkpoint="checkpoints/dvae.pth",
    xtts_checkpoint="./checkpoints/model.pth",  # checkpoint path of the model that you want to fine-tune
    tokenizer_file="checkpoints/vocab.json",
    gpt_num_audio_tokens=1026,
    gpt_start_audio_token=1024,
    gpt_stop_audio_token=1025,
    gpt_use_masking_gt_prompt_approach=True,
    gpt_use_perceiver_resampler=True,
)
# define audio config
audio_config = XttsAudioConfig(
    sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000
)
# training parameters config

config = GPTTrainerConfig()

config.load_json("checkpoints/config.json")

config.epochs = 3
config.output_path = "checkpoints"
config.model_args = model_args
config.audio = audio_config
config.batch_size = 2
config.num_loader_workers = 8
config.eval_split_max_size = 256
config.print_step = 50
config.plot_step = 100
config.log_model_step = 100
config.save_step = 100
config.save_n_checkpoints = 1
config.save_checkpoints = True
config.print_eval = False
config.optimizer = "AdamW"
config.optimizer_params = {
    "betas": [0.9, 0.96],
    "eps": 1e-8,
    "weight_decay": 0.99,
}
config.lr = 1e-4
config.lr_scheduler = "MultiStepLR"
config.lr_scheduler_params = {
    "milestones": [50000 * 18, 150000 * 18, 300000 * 18],
    "gamma": 0.5,
    "last_epoch": -1,
}
config.test_sentences = []

# init the model from config
model = GPTTrainer.init_from_config(config).xtts.gpt.conditioning_perceiver

<fsspec.implementations.local.LocalFileOpener object at 0x3c3161bd0>
<fsspec.implementations.local.LocalFileOpener object at 0x3174f2230>
<fsspec.implementations.local.LocalFileOpener object at 0x3174f3640>
<_io.BufferedReader name='checkpoints/dvae.pth'>
>> DVAE weights restored from: checkpoints/dvae.pth
<fsspec.implementations.local.LocalFileOpener object at 0x3174f1f00>



In [23]:
ours = PerceiverResampler(
    dim=1024,
    depth=2,
    dim_context=1024,
    num_latents=32,
    dim_head=64,
    heads=8,
    ff_mult=4,
    use_flash_attn=False,
    key=jax.random.PRNGKey(1),
)

our_x = jax.random.normal(jax.random.PRNGKey(1), shape=(2, 32, 1024))
their_x = torch.from_numpy(np.array(our_x))

their_params = {key: value for key, value in model.named_parameters()}
print(their_params.keys())

print(model.latents.size())
mapping = {
    "ff1.weight": "0.weight",
    "ff1.bias": "0.bias",
    "ff2.weight": "2.weight",
    "ff2.bias": "2.bias",
}


def update_weights(path, x):
    seq = [str(p).strip("[].") for p in path]
    path = ".".join(seq)
    if path in their_params:
        print(their_params[path].size())
        # if "bias" in path:
        #     return jnp.expand_dims(their_params[path].detach().numpy(), -1)
        return jnp.array(their_params[path].detach().numpy())
    if "layers" == seq[0] and "1" == seq[2]:
        if "ff1" == seq[3]:
            if "bias" in path:
                return jnp.array(
                    their_params[".".join([seq[0], seq[1], seq[2], "0", "bias"])]
                    .detach()
                    .numpy(),
                )
            else:
                return jnp.array(
                    their_params[".".join([seq[0], seq[1], seq[2], "0", "weight"])]
                    .detach()
                    .numpy()
                )
        if "ff2" == seq[3]:
            if "bias" in path:
                return jnp.array(
                    their_params[".".join([seq[0], seq[1], seq[2], "2", "bias"])]
                    .detach()
                    .numpy(),
                )
            else:
                return jnp.array(
                    their_params[".".join([seq[0], seq[1], seq[2], "2", "weight"])]
                    .detach()
                    .numpy()
                )
    if path in mapping.keys():
        if "bias" in path:
            return jnp.array(their_params[mapping[path]].detach().numpy())
        return jnp.array(their_params[mapping[path]].detach().numpy())
    # print(path)
    return x


def print_shapes(path, x):
    path = ".".join([str(p).strip("[].") for p in path])
    if "weight" in path or "bias" in path:
        print(f"Path {path} and shape {x.shape}")


# print(f"Latents: {ours.latents[0,0]}")
# jax.tree_util.tree_map_with_path(print_shapes, ours)
ours = jax.tree_util.tree_map_with_path(update_weights, ours)
# print(f"Latents: {ours.latents[0,0]}")
# jax.tree_util.tree_map_with_path(print_shapes, ours)
eqx.tree_serialise_leaves("xttsperciever.eqx", ours)
ours = eqx.tree_deserialise_leaves("xttsperciever.eqx", ours)
# print(f"Latents: {ours.latents.shape}")


our_y = ours(our_x)
their_y = model(their_x)

print(our_y[0])
print(their_y[0])
# their_y = their_y.squeeze(0)

torch.testing.assert_close(their_y, torch.from_numpy(np.array(our_y)))

dict_keys(['latents', 'layers.0.0.to_q.weight', 'layers.0.0.to_kv.weight', 'layers.0.0.to_out.weight', 'layers.0.1.0.weight', 'layers.0.1.0.bias', 'layers.0.1.2.weight', 'layers.0.1.2.bias', 'layers.1.0.to_q.weight', 'layers.1.0.to_kv.weight', 'layers.1.0.to_out.weight', 'layers.1.1.0.weight', 'layers.1.1.0.bias', 'layers.1.1.2.weight', 'layers.1.1.2.bias', 'norm.gamma'])
torch.Size([32, 1024])
torch.Size([32, 1024])
torch.Size([512, 1024])
torch.Size([1024, 1024])
torch.Size([1024, 512])
torch.Size([512, 1024])
torch.Size([1024, 1024])
torch.Size([1024, 512])
[ 0.01236053  0.01541404 -0.00212941 ... -0.00038619 -0.00196526
 -0.00434487]
[ 0.06295125  0.04312283  0.05844267 ... -0.11432284 -0.08922151
 -0.08559718]
[ 0.00548464 -0.04037053 -0.13012525 ...  0.21880269 -0.38889858
 -0.07342979]
[-4.2671423   2.1818194   0.6935338  ...  8.672872    0.17315355
 -1.6407791 ]
torch.Size([2, 32, 1024])
tensor([ 0.0124,  0.0154, -0.0021,  ..., -0.0004, -0.0020, -0.0043],
       grad_fn=<Select

AssertionError: Tensor-likes are not close!

Mismatched elements: 65467 / 65536 (99.9%)
Greatest absolute difference: 4.343595027923584 at index (1, 16, 440) (up to 1e-05 allowed)
Greatest relative difference: 1.0725865364074707 at index (0, 10, 440) (up to 1.3e-06 allowed)
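
`assert_close` only tells us that the final outputs differ, not where the two models start to diverge. One way to narrow it down is to compare intermediate activations after each sub-layer with a small helper like the one below. This is a generic debugging sketch, not part of XTTS: `report_diff` is a hypothetical name and the arrays here are dummy data.

```python
import numpy as np


def report_diff(name, ours, theirs, atol=1e-5):
    # Both inputs are converted to NumPy; call .detach().numpy() on torch tensors first.
    ours, theirs = np.asarray(ours), np.asarray(theirs)
    max_abs = float(np.max(np.abs(ours - theirs)))
    status = "OK" if max_abs <= atol else "MISMATCH"
    print(f"{name}: max abs diff {max_abs:.3e} {status}")
    return max_abs <= atol


a = np.linspace(0.0, 1.0, 5)
report_diff("identical", a, a.copy())
report_diff("perturbed", a, a + 1e-2)
```

Calling it on the latents after each attention and feed-forward step in both models would show the first place where the values stop agreeing.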