### Patchify
Cut a 2D image of shape $H \times W \times C$ ($H$: height, $W$: width, $C$: channels) into non-overlapping $p \times p$ patches, flatten each patch, and project it to dimension $D$. This produces a sequence of $(H/p)(W/p)$ tokens for the transformer to operate on.
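
As a quick sanity check of the shapes, here is a minimal sketch of the patchify step alone (no projection). It assumes channels-last input of shape (B, H, W, C) with H and W divisible by the patch size p. The helper name `patchify` and the example sizes are illustrative, not from the original code. A 256×256×3 image with p = 16 gives (256/16)·(256/16) = 256 tokens, each holding 16·16·3 = 768 values before projection.

```python
import jax.numpy as jnp

def patchify(x, p):
    """Split (B, H, W, C) images into non-overlapping p x p patches (hypothetical helper).

    Returns (B, (H // p) * (W // p), p * p * C): one flattened patch per token,
    with patches ordered row by row (left to right, top to bottom).
    """
    B, H, W, C = x.shape
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"
    x = x.reshape(B, H // p, p, W // p, p, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H//p, W//p, p, p, C)
    return x.reshape(B, (H // p) * (W // p), p * p * C)

# Shape check: 256 x 256 RGB image with 16 x 16 patches.
tokens = patchify(jnp.zeros((1, 256, 256, 3)), 16)
print(tokens.shape)  # (1, 256, 768)

# Content check: tiny 4 x 4 single-channel image with values 0..15 and 2 x 2 patches.
img = jnp.arange(16.0).reshape(1, 4, 4, 1)
small = patchify(img, 2)
print(small[0, 0])  # top-left patch: values 0, 1, 4, 5
print(small[0, 1])  # top-right patch: values 2, 3, 6, 7
```

Each token keeps the pixels of one spatial block together, which is why the transpose in step 2 is needed before flattening.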

Code reference: 

I read the original ViT paper to understand the Patchify logic and looked at how other people have implemented it, since the patch embedding logic is the same in DiT and ViT:

https://github.com/vballoli/vit-flax/blob/main/vit_flax/vit.py

I also used ChatGPT to help me understand what the functions do.

```python
import jax.numpy as jnp
import flax.linen as nn

class PatchEmbed(nn.Module):
    patch_size: int
    dimension: int

    @nn.compact
    def __call__(self, x):
        # x: shape (B, H, W, C)
        B, H, W, C = x.shape
        assert H % self.patch_size == 0 and W % self.patch_size == 0, \
            "H and W must be divisible by patch_size"
        patches_h = H // self.patch_size
        patches_w = W // self.patch_size
        total_num_patches = patches_h * patches_w

        # 1. reshape the images into patch grids
        # -> (B, patches_h, patch_size, patches_w, patch_size, C)
        x = x.reshape(B, patches_h, self.patch_size, patches_w, self.patch_size, C)

        # 2. transpose to bring the two within-patch axes together
        # -> (B, patches_h, patches_w, patch_size, patch_size, C)
        x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))

        # 3. flatten patch pixels -> (B, total_num_patches, patch_size*patch_size*C)
        x = x.reshape(B, total_num_patches, self.patch_size * self.patch_size * C)

        # 4. linear projection to "dimension" for the Transformer
        x = nn.Dense(self.dimension)(x)

        # 5. learned positional embedding: one D-dimensional vector per patch position
        pos = self.param("pos", nn.initializers.normal(0.02), (1, total_num_patches, self.dimension))
        x = x + pos

        return x, (patches_h, patches_w)
```


Reference code

https://github.com/kelechi-c/dit_flow/tree/main
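
Reshaping into patches and then applying one shared Dense layer is the same linear map as a convolution whose kernel size and stride both equal the patch size. As far as I know, this is how timm's `PatchEmbed` (which the official DiT code reuses) computes the patch embedding. Below is a minimal check in plain JAX; the helper names `patchify_dense` and `conv_embed` are hypothetical, and random weights stand in for learned ones.

```python
import jax
import jax.numpy as jnp

def patchify_dense(x, w, p):
    """Patchify (B, H, W, C), then project with a Dense weight w of shape (p*p*C, D)."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // p, p, W // p, p, C).transpose(0, 1, 3, 2, 4, 5)
    x = x.reshape(B, (H // p) * (W // p), p * p * C)
    return jnp.matmul(x, w, precision=jax.lax.Precision.HIGHEST)

def conv_embed(x, kernel, p):
    """Same projection as a stride-p convolution; kernel is (p, p, C, D) in HWIO layout."""
    y = jax.lax.conv_general_dilated(
        x, kernel,
        window_strides=(p, p),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        precision=jax.lax.Precision.HIGHEST,
    )
    B, h, w, D = y.shape
    # Flatten the (h, w) grid row by row, matching the patch order of patchify_dense.
    return y.reshape(B, h * w, D)

p, C, D = 4, 3, 8
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (2, 16, 16, C))
kernel = jax.random.normal(kw, (p, p, C, D))

# Flattening the kernel in (row, col, channel) order matches the patch flattening order.
dense_out = patchify_dense(x, kernel.reshape(p * p * C, D), p)
conv_out = conv_embed(x, kernel, p)
print(dense_out.shape, conv_out.shape)
print(jnp.allclose(dense_out, conv_out, atol=1e-4))
```

The convolution form avoids materializing the reshaped patch tensor, while the reshape-plus-Dense form used in `PatchEmbed` makes the flattening order explicit.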


# Transformer

Implementing the Transformer block, which applies these steps in order (pre-LayerNorm):
1. LayerNorm
2. Multi-Head Self Attention
3. Residual Connection
4. LayerNorm + MLP
5. Residual Connection

### Multi-Head Self Attention

Reference Code

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial6/Transformers_and_MHAttention.html

I learned how to write the JAX version of multi-head attention from this tutorial, along with some JAX syntax (for example, implementing scaled dot-product attention in JAX).
I also used ChatGPT to clarify what the functions do and parts of the logic, and to help with debugging.
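
For reference, the code below computes standard scaled dot-product attention per head, with $d_k$ the per-head dimension (`head_dimension`), $H$ the number of heads, $X \in \mathbb{R}^{N \times D}$ the token sequence, and biases of the Dense layers omitted:

$$
[\,Q \mid K \mid V\,] = X W^{QKV}, \qquad
\mathrm{Attention}(Q_i, K_i, V_i) = \mathrm{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i
$$

$$
\mathrm{MHSA}(X) = \mathrm{Concat}\big(\mathrm{Attention}(Q_1, K_1, V_1), \dots, \mathrm{Attention}(Q_H, K_H, V_H)\big)\, W^{O}
$$

Here $Q_i, K_i, V_i \in \mathbb{R}^{N \times d_k}$ are the slices for head $i$ (the reshape to `(B, N, 3, H, head_dimension)`), and $W^{O}$ is the final `nn.Dense(D)` projection.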

```python
import math
import jax.numpy as jnp
from flax import linen as nn

class MultiHeadSelfAttention(nn.Module):
    dimension: int
    num_heads: int
    dropout: float = 0.0

    @nn.compact
    def __call__(self, x, mask=None, deterministic=True):
        # x.shape = (batch_size, num_tokens, dimension)
        # every token is a D-dimensional vector carrying the information of one image patch
        B, N, D = x.shape
        H = self.num_heads
        assert D == self.dimension, "last axis of x must equal `dimension`"
        assert D % H == 0, "dimension must be divisible by number of heads"
        head_dimension = D // H

        # QKV projection: one Dense produces Q, K and V for all heads at once
        qkv = nn.Dense(D * 3)(x)  # qkv.shape = (B, N, 3 * D)
        qkv = qkv.reshape(B, N, 3, H, head_dimension)
        qkv = qkv.transpose(2, 0, 3, 1, 4)  # qkv.shape = (3, B, H, N, head_dimension)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention, computed independently for each head
        values, attention = self.scaled_dot_product(q, k, v, mask=mask)
        # Apply dropout to the attention output (needs a "dropout" RNG when deterministic=False)
        values = nn.Dropout(rate=self.dropout)(values, deterministic=deterministic)

        # Concatenate heads: (B, H, N, head_dim) -> (B, N, H * head_dim) = (B, N, D), then project
        values = values.transpose(0, 2, 1, 3).reshape(B, N, D)
        output = nn.Dense(D)(values)

        return output, attention

    @staticmethod
    def scaled_dot_product(q, k, v, mask=None):
        # q, k, v shape = (B, H, N, head_dim); mask (if given) must broadcast to (B, H, N, N)
        d_k = q.shape[-1]

        # Step 1: compute attention logits = Q @ K^T / sqrt(d_k) -> (B, H, N, N)
        attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
        attn_logits = attn_logits / math.sqrt(d_k)

        # Step 2: apply mask if provided (positions where mask == 0 are ignored)
        if mask is not None:
            attn_logits = jnp.where(mask == 0, -9e15, attn_logits)

        # Step 3: softmax over the key axis to get normalized attention weights
        attention = nn.softmax(attn_logits, axis=-1)

        # Step 4: weighted sum of values -> (B, H, N, head_dim)
        values = jnp.matmul(attention, v)

        return values, attention
```
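
A small check of the mask convention used in `scaled_dot_product` (entries equal to 0 are ignored). The function is copied as a standalone helper so the snippet runs on its own; the key-padding mask of shape (B, 1, 1, N), which broadcasts against the (B, H, N, N) logits, and all sizes are illustrative assumptions.

```python
import math
import jax
import jax.numpy as jnp

def scaled_dot_product(q, k, v, mask=None):
    # q, k, v: (B, H, N, head_dim); mask must broadcast to (B, H, N, N)
    attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(q.shape[-1])
    if mask is not None:
        attn_logits = jnp.where(mask == 0, -9e15, attn_logits)
    attention = jax.nn.softmax(attn_logits, axis=-1)
    return jnp.matmul(attention, v), attention

B, H, N, hd = 2, 4, 6, 8
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, H, N, hd))
k = jax.random.normal(kk, (B, H, N, hd))
v = jax.random.normal(kv, (B, H, N, hd))

# Key-padding mask: treat the last 2 tokens as padding (0 = ignore).
key_mask = jnp.array([1, 1, 1, 1, 0, 0]).reshape(1, 1, 1, N)
values, attention = scaled_dot_product(q, k, v, mask=key_mask)

print(attention.shape)               # (2, 4, 6, 6)
print(attention[..., -2:].max())     # masked keys receive zero weight
print(attention.sum(axis=-1).min())  # each row still sums to (about) 1

# Changing the values at masked positions does not change the output.
v_changed = v.at[:, :, -2:, :].set(100.0)
values_changed, _ = scaled_dot_product(q, k, v_changed, mask=key_mask)
print(jnp.allclose(values, values_changed))
```

Using a large negative constant instead of `-inf` keeps the softmax finite even if every key in a row were masked.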
