# Deep Dive Tutorial: From Autograd to Transformers

Welcome to the second part of our exploration into PyTorch. In the previous session, we established PyTorch as a powerful library for GPU-accelerated tensor computations. Today, we will build upon that foundation to understand how PyTorch enables the creation and training of complex neural networks.

We will cover:

1.  **The Magic Behind PyTorch: Autograd and Computation Graphs:** How PyTorch automatically calculates gradients, the bedrock of modern neural network training.
2.  **Two Philosophies of Model Building:** We'll compare the standard Object-Oriented (`nn.Module`) approach with the flexible functional (`torch.func`) paradigm.
3.  **Introducing `einops`:** A powerful and elegant library for tensor manipulation that will make your code more readable, reliable, and expressive.
4.  **Putting It All Together: Building a Transformer:** We will use our knowledge and `einops` to construct a Transformer, the architecture behind models like ChatGPT and AlphaFold.

In [ ]:
!pip install einops graphviz scikit-learn -q

## 1. The Magic Behind PyTorch: Autograd

At the heart of any deep learning framework is the ability to perform **automatic differentiation**, or **Autograd**. When we train a neural network, we need to adjust its parameters (weights and biases) to minimize a loss function. This adjustment is typically done using an optimization algorithm like Gradient Descent, which requires computing the gradient of the loss function with respect to every parameter in the model.

PyTorch automates this entire process with `torch.autograd`. Here's the core idea:

*   **Tracking Operations:** PyTorch keeps track of every operation performed on tensors that have their `requires_grad` attribute set to `True`.
*   **Building a DAG:** As operations are performed, PyTorch dynamically builds a **Directed Acyclic Graph (DAG)**. In this graph, the leaves are the input tensors (and model parameters), and the root is the output tensor (typically, the loss). The nodes in between represent the mathematical operations.
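*   **Backpropagation with `.backward()`:** When you call `.backward()` on the final output tensor (e.g., `loss.backward()`), PyTorch traverses this graph backward from the root. It uses the **chain rule** of calculus to compute the gradients at each step and accumulates them in the `.grad` attribute of every leaf tensor that has `requires_grad=True` (typically your model's parameters). Because gradients are accumulated rather than overwritten, training loops reset them (e.g., with `optimizer.zero_grad()`) before each backward pass.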
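A minimal sketch of these three steps on a toy example (the tensors and values are illustrative, not taken from a real model):

```python
import torch

# Leaf tensors: x_data is plain input data; w is a parameter we want gradients for
x_data = torch.tensor([1.0, 2.0, 3.0])
w = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)

# Forward pass: each operation involving w is recorded as a node in the graph
z = (x_data * w).sum()   # 0.5 - 2.0 + 6.0 = 4.5
loss = z ** 2            # 20.25
print(loss.grad_fn)      # the graph node that produced loss

# Backward pass: the chain rule gives dloss/dw = 2 * z * x_data = 9 * x_data
loss.backward()
print(w.grad)            # values [9., 18., 27.]
print(x_data.grad)       # None: x_data does not require gradients

# Gradients accumulate: a second forward/backward pass adds to w.grad
((x_data * w).sum() ** 2).backward()
print(w.grad)            # values [18., 36., 54.]
```

The last step is why the training loop below calls `optimizer.zero_grad()` before every `loss.backward()`.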



### Autograd's Computation Graph (DAG)

Here is a simplified view of the graph created during a forward pass and traversed during the backward pass.

In [ ]:
import graphviz

dot_source = """
digraph Autograd_DAG {
    rankdir=TB;
    node [shape=box, style="rounded,filled", fillcolor="lightblue"];
    edge [fontsize=10];

    subgraph cluster_forward {
        label = "Forward Pass: Building the Graph";
        style=filled;
        color=lightgrey;
        node [fillcolor="azure"];
        X [label="Input Tensor X"];
        W [label="Parameters W\nrequires_grad=True", style="rounded,filled", fillcolor="lightpink"];
        MatMul [label="z = X @ W"];
        LossFn [label="Loss = L(z)"];
        X -> MatMul;
        W -> MatMul;
        MatMul -> LossFn;
    }

    subgraph cluster_backward {
        label = "Backward Pass: Traversing the Graph";
        style=filled;
        color=lightgrey;
        node [fillcolor="honeydew"];
        Grad_Z [label="Compute dL/dz"];
        Grad_W [label="Compute dL/dW (Chain Rule)"];
        W_Update [label="Accumulate in W.grad", style="rounded,filled", fillcolor="lightpink"];
        LossFn -> Grad_Z [label="loss.backward()"];
        Grad_Z -> Grad_W;
        Grad_W -> W_Update;
    }
}
"""
graph = graphviz.Source(dot_source)
graph

## 2. Two Philosophies of Model Building

PyTorch offers two primary ways to structure your deep learning models: the traditional object-oriented approach and a more recent functional approach. Let's explore both by building a simple Multi-Layer Perceptron (MLP) to predict housing prices using the California Housing dataset.

In [ ]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.func import grad

# For loading real data
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# --- 0. Data Loading and Preprocessing ---
print("--- Loading and preparing data ---")
housing = fetch_california_housing()
X, y = housing.data, housing.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)

# --- Hyperparameters ---
INPUT_FEATURES = X_train.shape[1]
HIDDEN_FEATURES = 64
OUTPUT_FEATURES = 1
LEARNING_RATE = 0.001
EPOCHS = 20

## 1. The nn.Module Approach (Stateful Objects)
print("\n--- Training with nn.Module ---")

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(INPUT_FEATURES, HIDDEN_FEATURES)
        self.activation = nn.ReLU()
        self.layer2 = nn.Linear(HIDDEN_FEATURES, OUTPUT_FEATURES)

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        return x

model = SimpleMLP()
loss_fn_mse = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    model.train()
    y_pred = model(X_train)
    loss = loss_fn_mse(y_pred, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 5 == 0:
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}")

model.eval()
with torch.no_grad():
    y_test_pred = model(X_test)
    test_loss = loss_fn_mse(y_test_pred, y_test)
    print(f"Final Test MSE (nn.Module): {test_loss.item():.4f}")

## 2. The torch.func Approach (Stateless Functions)
print("\n--- Training with torch.func ---")

def init_params():
    params = {
        'w1': torch.randn(HIDDEN_FEATURES, INPUT_FEATURES) * (2 / INPUT_FEATURES)**0.5,
        'b1': torch.zeros(HIDDEN_FEATURES),
        'w2': torch.randn(OUTPUT_FEATURES, HIDDEN_FEATURES) * (2 / HIDDEN_FEATURES)**0.5,
        'b2': torch.zeros(OUTPUT_FEATURES)
    }
    return params

def mlp_fn(params, x):
    x = F.linear(x, params['w1'], params['b1'])
    x = F.relu(x)
    x = F.linear(x, params['w2'], params['b2'])
    return x

def compute_loss(params, x, y):
    y_pred = mlp_fn(params, x)
    return F.mse_loss(y_pred, y)

grad_fn = grad(compute_loss)
functional_params = init_params()

for epoch in range(EPOCHS):
    grads = grad_fn(functional_params, X_train, y_train)
    with torch.no_grad():
        for key in functional_params:
            functional_params[key] -= LEARNING_RATE * grads[key]
    if (epoch + 1) % 5 == 0:
        loss = compute_loss(functional_params, X_train, y_train)
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}")

with torch.no_grad():
    y_test_pred = mlp_fn(functional_params, X_test)
    test_loss = F.mse_loss(y_test_pred, y_test)
    print(f"Final Test MSE (torch.func): {test_loss.item():.4f}")

### Key Takeaway

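Both methods achieve the same goal. 👍 They are not numerically identical, though: the `nn.Module` loop uses the Adam optimizer, while the `torch.func` loop applies plain gradient descent with the same learning rate, so their losses after 20 epochs will generally differ.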

*   **`nn.Module`** is the standard, convenient choice for most projects. It bundles state (parameters) and logic (the `forward` method) together in an object.

*   **`torch.func`** decouples state and logic. This provides greater flexibility and is essential for advanced techniques like meta-learning or custom per-sample gradient computations, where you need to manipulate model parameters in more complex ways.
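Per-sample gradients are one such technique. A minimal sketch, assuming PyTorch 2.x with `torch.func`: `functional_call` lets you treat an ordinary `nn.Module` as a stateless function of a parameter dict, and composing `vmap` with `grad` yields one gradient per example instead of a single batch-averaged gradient. The small network (`tiny_model`) and the random data are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
tiny_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
# Pull the module's parameters out as a plain dict of tensors (the "state")
params = {name: p.detach() for name, p in tiny_model.named_parameters()}

def sample_loss(params, x, y):
    # functional_call runs tiny_model's forward with the supplied params;
    # x and y are single samples here, so add a batch axis of size 1
    pred = functional_call(tiny_model, params, (x.unsqueeze(0),))
    return F.mse_loss(pred, y.unsqueeze(0))

xb = torch.randn(32, 8)  # synthetic batch: 32 samples, 8 features
yb = torch.randn(32, 1)

# grad differentiates w.r.t. params; vmap maps over the batch axis of x and y only
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, xb, yb)
for name, g in per_sample_grads.items():
    print(name, tuple(g.shape))  # leading axis of 32: one gradient per sample
```

Averaging these per-sample gradients over the batch recovers the ordinary full-batch gradient of the mean-squared error.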

## 3. Introducing `einops`: Tutorial Part 1, Basics
## Welcome to einops-land!

We don't write 
```python
y = x.transpose(0, 2, 3, 1)  # NumPy; the PyTorch equivalent is x.permute(0, 2, 3, 1)
```
We write comprehensible code
```python
y = rearrange(x, 'b c h w -> b h w c')
```


`einops` supports widely used tensor packages (such as `numpy`, `pytorch`, `jax`, `tensorflow`), and extends them.

## What's in this tutorial?

- fundamentals: reordering, composition and decomposition of axes
- operations: `rearrange`, `reduce`, `repeat`
- how much you can do with a single operation!

## Preparations

To run this notebook, you will need to download the resources from the einops repository.

In [ ]:
import numpy as np
from IPython import get_ipython
from IPython.display import display_html
from PIL.Image import fromarray
from einops import rearrange, reduce, repeat

# The following code is inlined from the original tutorial's utils.py
def display_np_arrays_as_images():
    def np_to_png(a):
        if 2 <= len(a.shape) <= 3:
            return fromarray(np.array(np.clip(a, 0, 1) * 255, dtype="uint8"))._repr_png_()
        else:
            return fromarray(np.zeros([1, 1], dtype="uint8"))._repr_png_()

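    def np_to_text(obj, p, cycle):
        # 0-D/1-D arrays: show the usual repr; 2-D/3-D arrays are rendered
        # as images by np_to_png, so print nothing; otherwise report the shape
        if len(obj.shape) < 2:
            print(repr(obj))
        elif len(obj.shape) <= 3:
            pass
        else:
            print(f"<array of shape {obj.shape}>")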

    # This will only work in IPython environments
    try:
        ipy = get_ipython()
        if ipy is None:
            return
        ipy.display_formatter.formatters["image/png"].for_type(np.ndarray, np_to_png)
        ipy.display_formatter.formatters["text/plain"].for_type(np.ndarray, np_to_text)
    except ImportError: pass

_style_inline = """<style>
.einops-answer {
    color: transparent;
    padding: 5px 15px;
    background-color: #def;
}
.einops-answer:hover { color: blue; }
</style>"""

def guess(x):
    display_html(
        _style_inline + f"<h4>Answer is: <span class='einops-answer'>{tuple(x)}</span> (hover to see)</h4>",
        raw=True,
    )

# Now run the functions
display_np_arrays_as_images()

# Download resources
!mkdir -p resources
!wget https://raw.githubusercontent.com/arogozhnikov/einops/main/docs/resources/test_images.npy -P resources/

## Load a batch of images to play with

In [ ]:
ims = np.load("./resources/test_images.npy", allow_pickle=False)
print(ims.shape, ims.dtype)

In [ ]:
# display the first image
ims[0]

In [ ]:
# rearrange, as the name suggests, rearranges elements
rearrange(ims[0], "h w c -> w h c")

## Composition of axes

In [ ]:
# einops allows seamlessly composing batch and height to a new height dimension
rearrange(ims, "b h w c -> (b h) w c")

In [ ]:
# resulting dimensions are computed very simply
rearrange(ims, "b h w c -> h (b w) c").shape

## Decomposition of an axis

In [ ]:
# decomposition is the inverse process - represent an axis as a combination of new axes
rearrange(ims, "(b1 b2) h w c -> b1 b2 h w c ", b1=2).shape

## Meet einops.reduce

In [ ]:
# average over batch
reduce(ims, "b h w c -> h w c", "mean")

In [ ]:
# max-pooling with a kernel 2x2
reduce(ims, "b (h h1) (w w1) c -> b h w c", "max", h1=2, w1=2)

## Repeating elements

In [ ]:
# repeat along a new axis. New axis can be placed anywhere
repeat(ims[0], "h w c -> h new_axis w c", new_axis=5).shape

## Einops tutorial, part 2: deep learning

The previous part of the tutorial provided visual examples with NumPy; this part applies the same operations to PyTorch tensors.

## What's in this tutorial?

- working with deep learning packages
- important cases for deep learning models
- `einops.layers`

In [ ]:
import torch
import numpy as np
from einops import rearrange, reduce
x_np = np.random.RandomState(42).normal(size=[10, 32, 100, 200])
x = torch.from_numpy(x_np)
x.requires_grad = True

## Backpropagation

- Gradients are a cornerstone of deep learning
- You can back-propagate through einops operations

In [ ]:
y0 = x
y1 = reduce(y0, "b c h w -> b c", "max")
y2 = rearrange(y1, "b c -> c b")
y3 = reduce(y2, "c b -> ", "sum")
y3.backward()
print(reduce(x.grad, "b c h w -> ", "sum"))

## Common building blocks of deep learning

Let's check how some familiar operations can be written with `einops`

**Flattening** is a common operation that frequently appears at the boundary
between convolutional layers and fully connected layers.

In [ ]:
y = rearrange(x, "b c h w -> b (c h w)")
y.shape

**space-to-depth**

In [ ]:
y = rearrange(x, "b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=2, w1=2)
y.shape

**depth-to-space** (notice that it is the reverse of the previous operation)

In [ ]:
y = rearrange(x, "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=2, w1=2)
y.shape

## Layers

For frameworks that prefer operating with layers, `einops` layers are available.


In [ ]:
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from einops.layers.torch import Rearrange

# A simple LeNet-style model for image classification
model_with_einops = Sequential(
    # Assuming input is (batch, 3, 32, 32)
    Conv2d(3, 6, kernel_size=5),   # -> (batch, 6, 28, 28)
    ReLU(),
    MaxPool2d(kernel_size=2),      # -> (batch, 6, 14, 14)
    Conv2d(6, 16, kernel_size=5),  # -> (batch, 16, 10, 10)
    ReLU(),
    MaxPool2d(kernel_size=2),      # -> (batch, 16, 5, 5)
    # Flatten the feature map: (batch, 16, 5, 5) -> (batch, 400)
    Rearrange('b c h w -> b (c h w)'),
    Linear(16 * 5 * 5, 120),
    ReLU(),
    Linear(120, 10),               # 10 class logits
)

In [ ]:
# Let's check that the model works as expected
dummy_input = torch.randn(1, 3, 32, 32)
output = model_with_einops(dummy_input)
print(model_with_einops)
print("\nOutput shape:", output.shape)

## 4. Putting It All Together: Building a Transformer

The Transformer architecture has become the foundation of modern AI. Its core component is the **Multi-Head Self-Attention (MHSA)** mechanism. Before writing the code, let's build a strong intuition for how `einops` and `einsum` make implementing attention remarkably elegant.

### Bridging Theory and Code: Attention with `einops` and `einsum`

At its heart, attention is a mechanism to compute a weighted sum of values, where the weights are dynamically computed based on the similarity between a "query" and all "keys".

**The Steps of Attention:**

1.  **Similarity Score:** We compute the dot product between each Query (Q) and all Keys (K). This is a perfect use case for `torch.einsum`, which expresses this complex batch matrix multiplication concisely:
    *   `torch.einsum('b h i d, b h j d -> b h i j', q, k)`
    *   **Translation:** For each item in the **b**atch and each **h**ead, multiply the query token `i` (of dimension **d**) with each key token `j` (of dimension **d**) to produce a similarity matrix of shape (`b`, `h`, `i`, `j`).

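2.  **Scaling:** We multiply the scores by `1/sqrt(d_k)`, where `d_k` is the head dimension. The variance of a dot product between two random `d_k`-dimensional vectors grows with `d_k`; without scaling, large scores push the softmax into saturated regions where its gradients are vanishingly small, which destabilizes training.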

3.  **Softmax:** A softmax over the key axis `j` converts each query's scores into probabilities (weights that sum to 1).

4.  **Weighted Sum:** We multiply the attention scores with the Values (V). `einsum` again makes this clear:
    *   `torch.einsum('b h i j, b h j d -> b h i d', attn, v)`
    *   **Translation:** For each item in the **b**atch and each **h**ead, multiply the attention scores (`i`, `j`) with each value token `j` (of dimension **d**) to produce a new weighted representation for token `i`.
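The four steps can be checked directly. A minimal sketch on random tensors, compared against `torch.nn.functional.scaled_dot_product_attention` (this assumes PyTorch 2.0 or later, where that function exists; the sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, n, d = 2, 4, 5, 8  # batch, heads, tokens, head dimension
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

scores = torch.einsum("b h i d, b h j d -> b h i j", q, k)  # step 1: similarity
scores = scores / d ** 0.5                                    # step 2: scale by 1/sqrt(d_k)
attn = scores.softmax(dim=-1)                                 # step 3: normalize over keys j
out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)   # step 4: weighted sum of values

# Each query's weights sum to 1, and the result matches PyTorch's fused implementation
assert torch.allclose(attn.sum(dim=-1), torch.ones(b, h, n))
assert torch.allclose(out, F.scaled_dot_product_attention(q, k, v), atol=1e-5)
print(out.shape)
```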

**The Role of `rearrange` in Multi-Head Attention:**

Multi-head attention requires splitting a single large Q, K, or V tensor into multiple smaller "heads". `rearrange` handles this decomposition and the reverse composition elegantly:

*   **Splitting:** `rearrange(qkv, 'b n (h d) -> b h n d', h=num_heads)`
    *   This takes a tensor of shape `(batch, sequence_len, heads * head_dim)` and splits the last dimension into two new ones: `h` and `d`.
*   **Combining:** `rearrange(out, 'b h n d -> b n (h d)')`
    *   This performs the inverse operation, merging the `h` and `d` dimensions back together.

This combination of `torch.einsum` for the core computation and `einops.rearrange` for structuring the data is what makes Transformer code written this way so readable.

### Scaled Dot-Product Attention

In [ ]:
dot_source = '''
digraph G {
    rankdir=TB;
    node [shape=box, style=filled, fillcolor="lightblue"];
    Q [label="Query"]; K [label="Key"]; V [label="Value"];
    MatMul1 [label="MatMul"];
    TransposeK [label="Transpose"];
    Scale [label="Scale by 1/√d_k"];
    OptionalMask [label="Optional Mask", style=dashed];
    Softmax [fillcolor="lightyellow"];
    MatMul2 [label="MatMul"];
    Output [shape=ellipse, fillcolor="lightgreen"];
    Q -> MatMul1; K -> TransposeK -> MatMul1;
    MatMul1 -> Scale -> OptionalMask -> Softmax -> MatMul2;
    V -> MatMul2 -> Output;
    labelloc="t"; label="Scaled Dot-Product Attention";
}
'''
graphviz.Source(dot_source)

### Multi-Head Attention

In [ ]:
dot_source = '''
digraph G {
    rankdir=TB;
    node [shape=box, style=rounded];
    subgraph cluster_input {
        label = "Input Projection"; style=filled; color=lightgrey;
        Input [label="Input x"];
        LinearQ [label="Linear_Q"]; LinearK [label="Linear_K"]; LinearV [label="Linear_V"];
        Input -> LinearQ -> Q;
        Input -> LinearK -> K;
        Input -> LinearV -> V;
    }
    subgraph cluster_heads {
        label = "Parallel Attention Heads"; style=filled; color=lightgrey;
        node [style=filled, fillcolor="lightblue"];
        Head1 [label="Head 1\nScaled Dot-Product"];
        Head2 [label="Head 2\nScaled Dot-Product"];
        HeadN [label="...\nHead n"];
    }
    Q -> Head1; K -> Head1; V -> Head1;
    Q -> Head2; K -> Head2; V -> Head2;
    Q -> HeadN; K -> HeadN; V -> HeadN;
    subgraph cluster_output {
        label = "Output Stage"; style=filled; color=lightgrey;
        Concat [label="Concatenate"];
        LinearOut [label="Final Linear Layer"];
        OutputMHA [label="Multi-Head Output", shape=ellipse, style=filled, fillcolor="lightgreen"];
    }
    Head1 -> Concat; Head2 -> Concat; HeadN -> Concat;
    Concat -> LinearOut -> OutputMHA;
}
'''
graphviz.Source(dot_source)

### Architecture 1: Transformer Encoder

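The Encoder's job is to map an input sequence of symbol representations (x₁, ..., xₙ) to a sequence of continuous representations z = (z₁, ..., zₙ). It is composed of a stack of identical layers, each having two sub-layers: a multi-head self-attention mechanism and a simple, position-wise fully connected feed-forward network. Residual connections and layer normalization are used around each of the two sub-layers. The original Transformer normalizes *after* each residual addition (post-norm); the diagram and code below use the now-common pre-norm variant, which normalizes each sub-layer's *input* and then adds the residual: `x = sublayer(LayerNorm(x)) + x`.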
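The two orderings differ only in where the normalization sits. A minimal sketch (the helper functions `post_norm_step` and `pre_norm_step` are hypothetical, and a single `nn.Linear` stands in for a real sub-layer); PyTorch's built-in `nn.TransformerEncoderLayer` selects between the two with its `norm_first` flag:

```python
import torch
from torch import nn

def post_norm_step(x, sublayer, norm):
    # Original paper: add the residual first, then normalize
    return norm(x + sublayer(x))

def pre_norm_step(x, sublayer, norm):
    # Pre-norm: normalize the sub-layer input, add the raw residual afterwards
    return x + sublayer(norm(x))

seq = torch.randn(1, 10, 64)  # (batch, sequence, dimension)
norm, ff = nn.LayerNorm(64), nn.Linear(64, 64)  # stand-ins for a real sub-layer
print(post_norm_step(seq, ff, norm).shape, pre_norm_step(seq, ff, norm).shape)

# Built-in layer: norm_first=False is post-norm, True is pre-norm (dropout disabled)
for norm_first in (False, True):
    layer = nn.TransformerEncoderLayer(
        d_model=64, nhead=8, dim_feedforward=128, dropout=0.0,
        activation="gelu", batch_first=True, norm_first=norm_first,
    )
    print(norm_first, layer(seq).shape)
```

With post-norm, every block's output is layer-normalized; with pre-norm, the residual stream itself is never normalized, which is commonly credited with making deep stacks easier to train.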

In [ ]:
dot_source = '''
digraph G {
    rankdir=TB;
    node [shape=box, style=rounded];
    X_in [label="Input x"];
    AddNorm1 [label="Add", shape=circle];
    AddNorm2 [label="Add", shape=circle];
    X_out [label="Output"];
    MHA [label="Multi-Head\nAttention", style=filled, fillcolor="lightpink"];
    FF [label="Feed Forward", style=filled, fillcolor="skyblue"];
    LN1 [label="LayerNorm"]; LN2 [label="LayerNorm"];
    X_in -> LN1 -> MHA -> AddNorm1;
    X_in -> AddNorm1 [label=" Residual"];
    AddNorm1 -> LN2 -> FF -> AddNorm2;
    AddNorm1 -> AddNorm2 [label=" Residual"];
    AddNorm2 -> X_out;
    labelloc="t";
    label="Transformer Encoder Block";
}
'''
graphviz.Source(dot_source)

In [ ]:
import torch
from torch import nn
from einops import rearrange

# 1. Multi-Head Self-Attention (MHSA)
class Attention(nn.Module):
    def __init__(self, dim, n_heads, head_dim):
        super().__init__()
        self.n_heads = n_heads
        inner_dim = n_heads * head_dim
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        # Output projection back to the model dimension ("Final Linear Layer" in the
        # diagram); also keeps the residual valid when n_heads * head_dim != dim
        self.to_out = nn.Linear(inner_dim, dim)
        self.scale = head_dim ** -0.5

    def forward(self, x):
        # x: (batch, sequence, dimension)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.n_heads), qkv)

        # Scaled Dot-Product Attention using einsum
        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

        # Merge the heads, then project back to the model dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 2. Feed-Forward Network
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )
    def forward(self, x):
        return self.net(x)

# 3. Transformer Encoder Block
class Transformer(nn.Module):
    def __init__(self, dim, n_heads, head_dim, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, n_heads, head_dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim)

    def forward(self, x):
        # Attention block with pre-normalization and residual connection
        x = self.attn(self.norm1(x)) + x
        # Feed-forward block with pre-normalization and residual connection
        x = self.ff(self.norm2(x)) + x
        return x

# --- Example Usage ---
batch_size = 1
sequence_length = 10
embedding_dim = 64
num_heads = 8
head_dimension = 8
mlp_hidden_dim = 128

input_tensor = torch.randn(batch_size, sequence_length, embedding_dim)

transformer_encoder = Transformer(
    dim=embedding_dim,
    n_heads=num_heads,
    head_dim=head_dimension,
    mlp_dim=mlp_hidden_dim
)

output = transformer_encoder(input_tensor)
print("Input Shape:", input_tensor.shape)
print("Output Shape:", output.shape)

### Architecture 2: Transformer Decoder

The Decoder is also composed of a stack of identical layers. In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head cross-attention over the output of the encoder stack. The self-attention sub-layer in the decoder stack is also modified to prevent positions from attending to subsequent positions (Masked Attention).
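The masking idea can be sketched on its own. Assuming a single head and random inputs (all names here are illustrative), a boolean upper-triangular mask sets the scores of future positions to `-inf`, so the softmax gives them exactly zero weight. As a consequence, perturbing the last token leaves every earlier output unchanged:

```python
import torch

def causal_attention(q, k, v):
    # q, k, v: (batch, tokens, dim); single head for clarity
    scores = torch.einsum("b i d, b j d -> b i j", q, k) / q.shape[-1] ** 0.5
    # True above the diagonal, i.e. where key position j > query position i
    future = torch.ones(q.shape[1], k.shape[1], dtype=torch.bool).triu(1)
    scores = scores.masked_fill(future, float("-inf"))  # exp(-inf) = 0 in softmax
    attn = scores.softmax(dim=-1)
    return attn @ v, attn

torch.manual_seed(0)
q, k, v = (torch.randn(1, 4, 8) for _ in range(3))
out_a, attn_a = causal_attention(q, k, v)
print(attn_a[0])  # lower-triangular: row i attends only to positions 0..i

# Perturb only the last token; outputs for positions 0..2 must not change
q2, k2, v2 = q.clone(), k.clone(), v.clone()
for t in (q2, k2, v2):
    t[:, -1] += 10.0
out_b, _ = causal_attention(q2, k2, v2)
assert torch.allclose(out_a[:, :-1], out_b[:, :-1])
```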

In [ ]:
dot_source = '''
digraph G {
    rankdir=TB;
    node [shape=box, style=rounded];
    Tgt_in [label="Target Input"];
    Context [label="Encoder Context"];
    AddNorm1 [label="Add", shape=circle];
    AddNorm2 [label="Add", shape=circle];
    AddNorm3 [label="Add", shape=circle];
    Tgt_out [label="Output"];
    MMHA [label="Masked Multi-Head\nSelf-Attention", style=filled, fillcolor="lightpink"];
    MHA [label="Multi-Head\nCross-Attention", style=filled, fillcolor="lightblue"];
    FF [label="Feed Forward", style=filled, fillcolor="skyblue"];
    LN1 [label="LayerNorm"]; LN2 [label="LayerNorm"]; LN3 [label="LayerNorm"];
    Tgt_in -> LN1 -> MMHA -> AddNorm1;
    Tgt_in -> AddNorm1 [label=" Residual"];
    AddNorm1 -> LN2 -> MHA -> AddNorm2;
    AddNorm1 -> AddNorm2 [label=" Residual"];
    Context -> MHA;
    AddNorm2 -> LN3 -> FF -> AddNorm3;
    AddNorm2 -> AddNorm3 [label=" Residual"];
    AddNorm3 -> Tgt_out;
    labelloc="t";
    label="Transformer Decoder Block";
}
'''
graphviz.Source(dot_source)

In [ ]:
class MaskedAttention(nn.Module):
    def __init__(self, dim, n_heads, head_dim):
        super().__init__()
        self.n_heads = n_heads
        inner_dim = n_heads * head_dim
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)  # output projection back to dim
        self.scale = head_dim ** -0.5

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.n_heads), qkv)

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # Causal mask: True above the diagonal, where key position j > query position i.
        # A single (i, j) mask broadcasts over the batch and head axes.
        i, j = dots.shape[-2:]
        mask = torch.ones(i, j, dtype=torch.bool, device=dots.device).triu(1)
        dots = dots.masked_fill(mask, float('-inf'))

        attn = dots.softmax(dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class CrossAttention(nn.Module):
    def __init__(self, dim, n_heads, head_dim):
        super().__init__()
        self.n_heads = n_heads
        inner_dim = n_heads * head_dim
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)  # output projection back to dim
        self.scale = head_dim ** -0.5

    def forward(self, x, context):
        # Q from decoder, K and V from encoder context
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.n_heads), (q, k, v))

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Decoder-Only block (e.g., for a GPT-like model without cross-attention)
class DecoderOnlyBlock(nn.Module):
    def __init__(self, dim, n_heads, head_dim, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MaskedAttention(dim, n_heads, head_dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim)

    def forward(self, x):
        x = self.attn(self.norm1(x)) + x
        x = self.ff(self.norm2(x)) + x
        return x


### Architecture 3: Encoder-Decoder Transformer

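This is the full architecture, combining the Encoder and Decoder stacks. The output of the top encoder serves as the context from which the cross-attention sub-layer of every decoder block computes its keys K and values V, allowing the decoder to focus on the relevant parts of the input sequence. For brevity, the code below omits the embedding layers and the final Linear + Softmax head shown in the diagram: it operates directly on sequences that are already embedded in `dim` dimensions.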
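For comparison, PyTorch ships a reference encoder-decoder, `nn.Transformer`. A minimal sketch with hyperparameters chosen to mirror the example below (dropout disabled, pre-norm enabled); note that the built-in version includes biases and a final LayerNorm on each stack, so its numbers will not match the hand-written model. `generate_square_subsequent_mask` builds the float causal mask (`-inf` above the diagonal) for decoder self-attention:

```python
import torch
from torch import nn

ref = nn.Transformer(
    d_model=64, nhead=8, num_encoder_layers=3, num_decoder_layers=3,
    dim_feedforward=128, dropout=0.0, batch_first=True, norm_first=True,
)
src = torch.randn(1, 15, 64)  # (batch, source length, d_model), already embedded
tgt = torch.randn(1, 20, 64)  # (batch, target length, d_model), already embedded

# 0.0 on and below the diagonal, -inf above it
tgt_mask = nn.Transformer.generate_square_subsequent_mask(20)
print(ref(src, tgt, tgt_mask=tgt_mask).shape)
```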

In [ ]:
dot_source = '''
digraph G {
    rankdir=TB;
    graph [compound=true];
    node [shape=box, style=rounded];
    subgraph cluster_encoder {
        label = "Encoder Stack";
        rankdir=LR;
        InputSeq [label="Input Sequence"];
        EmbIn [label="Embedding"];
        Enc1 [label="Encoder Block 1", style=filled, fillcolor="lightpink"];
        EncN [label="Encoder Block N", style=filled, fillcolor="lightpink"];
        ContextOut [label="Context (K, V)"];
        InputSeq -> EmbIn -> Enc1 -> EncN -> ContextOut;
    }
    subgraph cluster_decoder {
        label = "Decoder Stack";
        rankdir=LR;
        OutputSeq [label="Output Sequence"];
        EmbOut [label="Embedding"];
        Dec1 [label="Decoder Block 1", style=filled, fillcolor="lightblue"];
        DecN [label="Decoder Block N", style=filled, fillcolor="lightblue"];
        OutputProbs [label="Linear + Softmax"];
        OutputSeq -> EmbOut -> Dec1 -> DecN -> OutputProbs;
    }
    ContextOut -> Dec1 [lhead=cluster_decoder, ltail=cluster_encoder];
}
'''
graphviz.Source(dot_source)

In [ ]:
# Full Decoder Block for an Encoder-Decoder model
class TransformerDecoderBlock(nn.Module):
    def __init__(self, dim, n_heads, head_dim, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.masked_attn = MaskedAttention(dim, n_heads, head_dim)
        self.norm2 = nn.LayerNorm(dim)
        self.cross_attn = CrossAttention(dim, n_heads, head_dim)
        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim)

    def forward(self, x, context):
        # Masked self-attention
        x = self.masked_attn(self.norm1(x)) + x
        # Cross-attention with encoder context
        x = self.cross_attn(self.norm2(x), context) + x
        # Feed-forward
        x = self.ff(self.norm3(x)) + x
        return x

# Full Encoder-Decoder Model
class EncoderDecoder(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, dim, n_heads, head_dim, mlp_dim):
        super().__init__()
        self.encoder = nn.ModuleList([Transformer(dim, n_heads, head_dim, mlp_dim) for _ in range(num_encoder_layers)])
        self.decoder = nn.ModuleList([TransformerDecoderBlock(dim, n_heads, head_dim, mlp_dim) for _ in range(num_decoder_layers)])

    def forward(self, src, tgt):
        # Encode the source sequence
        encoded_src = src
        for encoder_layer in self.encoder:
            encoded_src = encoder_layer(encoded_src)

        # Decode using the encoded source as context
        decoded_tgt = tgt
        for decoder_layer in self.decoder:
            decoded_tgt = decoder_layer(decoded_tgt, encoded_src)
        return decoded_tgt

# --- Example Usage ---
src_seq_len = 15
tgt_seq_len = 20

source_seq = torch.randn(batch_size, src_seq_len, embedding_dim)
target_seq = torch.randn(batch_size, tgt_seq_len, embedding_dim)

encoder_decoder = EncoderDecoder(
    num_encoder_layers=3, num_decoder_layers=3,
    dim=embedding_dim, n_heads=num_heads,
    head_dim=head_dimension, mlp_dim=mlp_hidden_dim
)

output = encoder_decoder(source_seq, target_seq)
print("Source Shape:", source_seq.shape)
print("Target Shape:", target_seq.shape)
print("Final Output Shape:", output.shape)