# 🔢 HW 5: Learning to Integrate Using Transformers

<font color='red'>**Deadline: 22.05.2024**</font>

In this notebook, we explore the task of **symbolic integration** using transformer-based models. We will be working with a pretrained transformer developed by Facebook AI Research (Meta AI) that was specifically designed for mathematical tasks.

As discussed previously, training such models from scratch is computationally expensive. Therefore, we will use an existing pretrained model and demonstrate how it can be used to perform integration tasks.



## 📦 Step 1: Clone the Repository and Import Transformer Components


We begin by cloning the GitHub repository that contains the necessary model definitions, pretrained weights, and utility scripts for working with mathematical problems.


In [None]:
!git clone https://github.com/facebookresearch/SymbolicMathematics.git

In [None]:
import os
import numpy as np
import sympy as sp
import torch
from logging import getLogger

Now that we have the repository cloned, we’ll load the core transformer building blocks and utility functions. Although the implementation may differ slightly from our lecture examples, the overall API and workflow remain the same.


In [None]:
from SymbolicMathematics.src.model import build_modules
from SymbolicMathematics.src.utils import to_cuda
from SymbolicMathematics.src.envs.sympy_utils import simplify
from SymbolicMathematics.src.model.transformer import TransformerModel

TransformerModel.STORE_OUTPUTS = True

## 🧮 Step 2: Dataset Generation

To train and evaluate the transformer on symbolic integration tasks, the authors of the paper (Lample & Charton, *Deep Learning for Symbolic Mathematics*, ICLR 2020) proposed **three distinct methods** for constructing function-integral pairs:

### 🔁 1. Forward Generation (FWD)
- Randomly generate a function.
- Use a **Computer Algebra System (CAS)** to compute its integral.
- If the CAS cannot compute the integral, discard the pair.
- ✅ Ensures the result is a correct integral.
- ❌ May lead to a biased dataset, as only "easy-to-integrate" functions are retained.

### 🔄 2. Backward Generation (BWD)
- Randomly generate a function $F$ (that has known structure).
- Compute its **derivative** $f = F'$.
- Use the pair $(f, F)$ as a training sample.
- ✅ Guarantees correctness.
- ❌ Trains the model to "undo differentiation", which is not the same as general integration.

### 🧩 3. Backward Generation with Integration by Parts (IBP)
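- Generate two random functions $F$ and $G$.
- Compute their derivatives $f = F'$ and $g = G'$.
- If the integral of $f \cdot G$ is already known (i.e. it is in the dataset), integration by parts gives the integral of $F \cdot g$:
  $$
  \int F g = F G - \int f G
  $$
  Symmetrically, if the integral of $F \cdot g$ is known, then $\int f G = F G - \int F g$.
- ✅ Adds diversity: the dataset contains integrals obtained by integration by parts, not only by reversing a single differentiation.
- ✅ Mimics human mathematical reasoning.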
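A minimal SymPy sketch of the BWD and IBP constructions. It uses hand-picked functions instead of the paper's random expression generator, and `ibp_pair` is a hypothetical helper, not code from the repository:

```python
import sympy as sp

x = sp.Symbol("x")

# BWD: start from a known antiderivative and differentiate it.
F_b = sp.sin(x) * sp.exp(x)
f_b = sp.diff(F_b, x)
bwd_pair = (f_b, F_b)          # model input f, target F

# IBP: with f = F', g = G' and a known integral of f*G,
# integration by parts gives the integral of F*g.
def ibp_pair(F, G, int_fG):
    """Hypothetical helper: return (F*g, antiderivative of F*g)."""
    g = sp.diff(G, x)
    return F * g, F * G - int_fG

F_i, G_i = x, sp.exp(x)        # f = 1, g = exp(x)
int_fG = sp.exp(x)             # integral of 1 * exp(x), assumed already known
integrand, antiderivative = ibp_pair(F_i, G_i, int_fG)

# Sanity check: differentiating the result must give back the integrand.
print(integrand, "->", antiderivative)
print(sp.simplify(sp.diff(antiderivative, x) - integrand) == 0)
```

Here the IBP step turns the known integral of $e^x$ into the integral of $x e^x$, a pair that plain BWD would only produce if it happened to sample $x e^x - e^x$ as $F$.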

---

### 📌 Next Step: Choose a Model

We will now select one of the pretrained transformer models for evaluation and testing on integration problems.


<div align="center">
  <img src="https://drive.google.com/uc?id=1wK3J_CPmJkRlYyWK0qojAwCSEF4XmcPv" alt="Integration Transformer Pipeline" width="600"/>
</div>

Model links, in the order shown in the figure above:

- FWD: https://dl.fbaipublicfiles.com/SymbolicMathematics/models/fwd.pth
- BWD: https://dl.fbaipublicfiles.com/SymbolicMathematics/models/bwd.pth
- IBP: https://dl.fbaipublicfiles.com/SymbolicMathematics/models/ibp.pth
- FWD + BWD: https://dl.fbaipublicfiles.com/SymbolicMathematics/models/fwd_bwd.pth
- IBP + BWD: https://dl.fbaipublicfiles.com/SymbolicMathematics/models/ibp_bwd.pth
- FWD + BWD + IBP: https://dl.fbaipublicfiles.com/SymbolicMathematics/models/fwd_bwd_ibp.pth
- ODE1: https://dl.fbaipublicfiles.com/SymbolicMathematics/models/ode1.pth
- ODE2: https://dl.fbaipublicfiles.com/SymbolicMathematics/models/ode2.pth

To use a pretrained transformer model for integration, we download the weights for the **FWD+BWD** model — a combination of forward and backward generation strategies.

This model has demonstrated strong performance in generalizing across a wide range of symbolic integration tasks.

We will store the weights in a path that we can later pass to the model loader.

In [None]:
# Download the pretrained weights (FWD+BWD model)
!wget https://dl.fbaipublicfiles.com/SymbolicMathematics/models/fwd_bwd.pth -P pretrained_models/

In [None]:
# Define the path to the downloaded weights
MODEL_PATH = "pretrained_models/fwd_bwd.pth"

## 🧠 Step 3: Load the Transformer Model

We now define the configuration for our transformer model. These parameters are taken directly from the official [SymbolicMathematics GitHub repository](https://github.com/facebookresearch/SymbolicMathematics/blob/main/README.md).


We also specify that the model will run on **CPU** for compatibility, but you may switch to CUDA if available by setting `'cpu': False`.

The configuration is wrapped using the `AttrDict` class, which enables attribute-style access.
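Attribute-style access means `params.emb_dim` and `params['emb_dim']` refer to the same value. One common way to build such a class is sketched below; `ToyAttrDict` is our own illustrative name, and the repository's `AttrDict` may be implemented differently:

```python
class ToyAttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self   # attribute lookups now go through the dict itself

cfg = ToyAttrDict({'emb_dim': 1024, 'n_heads': 8, 'cpu': True})
print(cfg.emb_dim, cfg['n_heads'])   # attribute and key access are equivalent
cfg.cpu = False                      # setting an attribute updates the dict entry
print(cfg['cpu'])
```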

In [None]:
from SymbolicMathematics.src.utils import AttrDict

model_path = MODEL_PATH  # Make sure this path points to your pretrained weights

params = AttrDict({
    # environment parameters
    'env_name': 'char_sp',
    'int_base': 10,
    'balanced': False,
    'positive': True,
    'precision': 10,
    'n_variables': 1,
    'n_coefficients': 0,
    'leaf_probs': '0.75,0,0.25,0',
    'max_len': 512,
    'max_int': 5,
    'max_ops': 15,
    'max_ops_G': 15,
    'clean_prefix_expr': True,
    'rewrite_functions': '',
    'tasks': 'prim_fwd',
    'operators': 'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',

    # model parameters
    'cpu': True,  # Set to False if using CUDA
    'emb_dim': 1024,
    'n_enc_layers': 6,
    'n_dec_layers': 6,
    'n_heads': 8,
    'dropout': 0,
    'attention_dropout': 0,
    'sinusoidal_embeddings': False,
    'share_inout_emb': True,
    'reload_model': model_path,
})


We now load the pretrained transformer model using code from the official [SymbolicMathematics repository](https://github.com/facebookresearch/SymbolicMathematics/tree/main).

To support running on CPU (and optionally on GPU), we slightly modify the original `build_modules` function from [`src/model/__init__.py`](https://github.com/facebookresearch/SymbolicMathematics/blob/main/src/model/__init__.py) by explicitly setting the device based on availability and user preference.

This function:
- Initializes both **encoder** and **decoder** transformer modules.
- Loads pretrained weights from `params.reload_model`.
- Automatically removes the `'module.'` prefix if weights were saved with `nn.DataParallel`.



In [None]:
device = torch.device("cuda" if torch.cuda.is_available() and not params.cpu else "cpu")

def build_modules_cpu(env, params):
    """
    Load pretrained encoder and decoder modules with CPU/GPU compatibility.
    """
    logger = getLogger()

    # Instantiate encoder and decoder
    modules = {
        'encoder': TransformerModel(params, env.id2word, is_encoder=True, with_output=False),
        'decoder': TransformerModel(params, env.id2word, is_encoder=False, with_output=True),
    }

    # Load pretrained weights
    if params.reload_model != '':
        logger.info(f"Reloading modules from {params.reload_model} ...")
        reloaded = torch.load(params.reload_model, map_location=torch.device('cpu'))
        for name, model in modules.items():
            assert name in reloaded
            state_dict = reloaded[name]
            if all(k.startswith('module.') for k in state_dict):
                state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
            model.load_state_dict(state_dict)

    # Log parameter count
    for name, model in modules.items():
        logger.info(f"Number of parameters ({name}): {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Move to CUDA only if it was selected above (available and params.cpu is False)
    if device.type == "cuda":
        for model in modules.values():
            model.cuda()

    return modules

In [None]:
from SymbolicMathematics.src.envs import build_env

env = build_env(params)
x = env.local_dict['x']

if device.type == "cuda":  # compare the device type, not the torch.device object, with a string
    print("CUDA")
    modules = build_modules(env, params)
else:
    print("CPU")
    modules = build_modules_cpu(env, params)
encoder = modules['encoder']
decoder = modules['decoder']

## 🎯 Problem 1: Analyze Encoder and Decoder Structure (1 Point)

In this task, you will inspect the structure of the loaded Transformer **encoder** and **decoder**.

Your goals:
1. **Determine the number of attention heads**.
2. **Determine the number of attention layers**.
3. **Determine the total embedding dimension** used by the model.

> 💡 Recall:
> - The total embedding dimension is split across multiple attention heads.
> - You should **verify your answers using the actual model parameters**, not just the `params` dictionary.

### ✅ What to Do

- Use Python to explore the `encoder` and `decoder` objects.
- Print relevant layer configurations.
- Make sure to verify and **justify** your answers with code and comments.

### 📌 Output Format

- **Number of attention heads:** `...`
- **Number of attention layers:** `...`
- **Total embedding dimension:** `...`

Each attention head has embedding size: `total_emb_dim / num_heads`
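To make the head split concrete, here is a small NumPy sketch that reshapes toy hidden states into per-head slices using `emb_dim = 1024` and `n_heads = 8` from `params`. The sequence length and random data are arbitrary; this illustrates the arithmetic, not the repository's internal code:

```python
import numpy as np

emb_dim, n_heads = 1024, 8           # values from the params dictionary above
head_dim = emb_dim // n_heads        # embedding size per attention head
seq_len, batch = 5, 1                # arbitrary toy sizes

# Toy hidden states: [batch, seq_len, emb_dim]
h = np.random.randn(batch, seq_len, emb_dim)

# Split the last dimension into heads: [batch, n_heads, seq_len, head_dim]
per_head = h.reshape(batch, seq_len, n_heads, head_dim).transpose(0, 2, 1, 3)

# Concatenating the heads again recovers the original tensor
merged = per_head.transpose(0, 2, 1, 3).reshape(batch, seq_len, emb_dim)

print(head_dim, per_head.shape, np.allclose(merged, h))
```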


## 🧪 Problem 2 (4 points total)

### 🧪 Problem 2a: Explore SymPy & Tokenization

In this part, you will get familiar with how the symbolic math library **SymPy** handles expressions, and how the **SymbolicMathematics environment** tokenizes them into prefix notation.

We define a symbolic function $F$, compute its derivative $f = F'$, and then convert both expressions to prefix form using the environment's `sympy_to_prefix()` method.

> 🔍 Your task is to experiment by modifying `F_infix` and observing how SymPy and the environment behave. Try functions with:
> - Basic operations (e.g. `x**2 + 3*x`)
> - Trigonometric functions
> - Exponentials and logs
> - Nested function compositions

No submission is needed — just use this to make sure the code flow makes sense before we move on to prediction and evaluation.
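Prefix (Polish) notation writes each operator before its operands, so no parentheses are needed. The toy sketch below uses a hypothetical nested-tuple representation and simplified tokens; the environment's real tokenizer differs (for example, in how it encodes integers and which operators it supports):

```python
# Hypothetical toy representation: (operator, operand, ...) tuples; leaves are strings.
BINARY = {"add": "+", "sub": "-", "mul": "*", "div": "/"}
UNARY = {"sin", "cos", "exp", "ln"}

def to_prefix(expr):
    """Flatten a nested tuple expression into a list of prefix tokens."""
    if isinstance(expr, str):
        return [expr]
    op, *args = expr
    tokens = [op]
    for arg in args:
        tokens += to_prefix(arg)
    return tokens

def prefix_to_infix(tokens):
    """Rebuild a fully parenthesized infix string from prefix tokens."""
    def parse(i):
        tok = tokens[i]
        if tok in BINARY:
            left, i = parse(i + 1)
            right, i = parse(i)
            return f"({left} {BINARY[tok]} {right})", i
        if tok in UNARY:
            arg, i = parse(i + 1)
            return f"{tok}({arg})", i
        return tok, i + 1              # leaf: variable or number
    infix, end = parse(0)
    assert end == len(tokens), "leftover tokens: not a valid prefix expression"
    return infix

# x * cos(x) + 3
expr = ("add", ("mul", "x", ("cos", "x")), "3")
tokens = to_prefix(expr)
print(tokens)
print(prefix_to_infix(tokens))
```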

In [None]:
from IPython.display import display
import sympy as sp

F_infix = '1/cos(x)'
# F_infix = 'x * cos(x**2) * tan(x)'
# F_infix = 'ln(cos(x + exp(x)) * sin(x**2 + 2) * exp(x) / x)'
# F_infix = 'cos(x**2 * exp(x * cos(x)))'
# F_infix = 'x**2 + x '
# F_infix = '123 * exp(2*x)'
# F_infix = '1'

# Parse and differentiate
F = sp.S(F_infix, locals=env.local_dict)
f = F.diff(x)  # differentiate w.r.t. the environment's variable x (also works for constants)


# Convert to prefix notation
F_prefix = env.sympy_to_prefix(F)
f_prefix = env.sympy_to_prefix(f)

# Pretty output
print("Original infix string:")
print(F_infix)

print("\nParsed SymPy expression for F:")
display(F)

print("\nSymbolic derivative f = F':")
display(f)

print("\nPrefix form of F:")
print(F_prefix)

print("\nPrefix form of f:")
print(f_prefix)

### 🧪 Problem 2b: Run the Model to Predict the Integral

Now that we have the function $f = F'$, we ask the model to predict the original $F$ given only $f$.

To do this, we:

1. Create a **prefix-encoded input** to represent the symbolic query “what function has derivative $f$?”
2. Encode the input using the **encoder** module.
3. Use **beam search decoding** on the decoder to generate candidate antiderivatives.

In [None]:


x1_prefix = env.clean_prefix(['sub', 'derivative', 'f', 'x', 'x'] + f_prefix)

# Tokenize and encode the input sequence
x1 = torch.LongTensor(
    [env.eos_index] +
    [env.word2id[w] for w in x1_prefix] +
    [env.eos_index]
).view(-1, 1)

len1 = torch.LongTensor([len(x1)])

# Move the inputs to the same device as the model (a no-op on CPU)
x1, len1 = x1.to(device), len1.to(device)

We pass the symbolic input through the **encoder** to produce a representation of the function $f$.

In [None]:
# Encode the input with the Transformer encoder
with torch.no_grad():
    encoded = encoder('fwd', x=x1, lengths=len1, causal=False).transpose(0, 1)

We use **beam search** to generate multiple likely hypotheses for the original function $F$.

#### 🧠 What is Beam Search?

Beam search keeps the $k$ highest-scoring partial outputs (here, `beam_size = 10`) at every decoding step instead of only the single best one, as greedy decoding would. This lets the model explore several plausible symbolic expressions, which improves accuracy on structured tasks like integration: each candidate can be checked by differentiation, so a correct answer anywhere in the beam is useful.
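The procedure can be sketched in plain Python with a hypothetical toy scoring function standing in for the decoder. Unlike the repository's `generate_beam`, this sketch handles one sequence at a time, works on Python lists instead of tensors, and ignores the `length_penalty` option:

```python
import math

def beam_search(next_log_probs, beam_size, max_len, eos="</s>"):
    """Toy beam search.

    next_log_probs(prefix) -> {token: log_prob} for the next token.
    Returns finished hypotheses as (score, tokens), best first.
    Beams still unfinished after max_len steps are discarded.
    """
    beams = [(0.0, [])]                      # (cumulative log-prob, tokens)
    finished = []
    for _ in range(max_len):
        candidates = []
        for score, seq in beams:
            for tok, lp in next_log_probs(seq).items():
                candidates.append((score + lp, seq + [tok]))
        # Keep only the beam_size best expansions
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates[:beam_size]:
            if seq[-1] == eos:
                finished.append((score, seq))
            else:
                beams.append((score, seq))
        if not beams:
            break
    return sorted(finished, key=lambda c: c[0], reverse=True)

# Hypothetical next-token probabilities; unknown prefixes always end the sequence.
TABLE = {
    (): {"a": 0.6, "b": 0.4},
    ("a",): {"c": 0.5, "d": 0.5},
    ("b",): {"e": 0.9, "f": 0.1},
}

def toy_model(prefix):
    probs = TABLE.get(tuple(prefix), {"</s>": 1.0})
    return {tok: math.log(p) for tok, p in probs.items()}

best = beam_search(toy_model, beam_size=2, max_len=5)
greedy = beam_search(toy_model, beam_size=1, max_len=5)
print(best[0][1], round(math.exp(best[0][0]), 2))
print(greedy[0][1], round(math.exp(greedy[0][0]), 2))
```

With `beam_size=2` the search returns `b e </s>` (probability 0.36), whereas greedy search (`beam_size=1`) commits to the locally better first token `a` and ends with `a c </s>` (probability 0.30).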

In [None]:
beam_size = 10

with torch.no_grad():
    _, _, beam = decoder.generate_beam(
        encoded,
        len1,
        beam_size=beam_size,
        length_penalty=1.0,
        early_stopping=1,
        max_len=200
    )

# Check the number of hypotheses generated
assert len(beam) == 1
hypotheses = beam[0].hyp
assert len(hypotheses) == beam_size

Now we:
- Convert each predicted prefix expression into infix and then to a SymPy object.
- Symbolically differentiate the predicted expression.
- Compare it to the original function \( f \).

In [None]:
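# Use the repository's time-limited simplify(expr, seconds) from Step 1;
# sympy.simplify has no `seconds` time limit.
from SymbolicMathematics.src.envs.sympy_utils import simplify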
from IPython.display import display, Math

def display_model_predictions(f, F, hypotheses, env):
    """
    Display model predictions in a formatted mathematical style.

    Parameters:
    - f: SymPy expression (the derivative)
    - F: SymPy expression (the reference antiderivative)
    - hypotheses: list of (score, tensor) pairs returned by beam search
    - env: SymbolicMathematics environment, providing prefix <-> infix conversions
    """
    print("🔢 Input function (f):")
    display(f)

    if F is not None:  # F is None when only the derivative is known
        print("\n📌 Reference antiderivative (F):")
        display(F)

    print("\n🔍 Model Predictions:")

    for score, sent in sorted(hypotheses, key=lambda h: h[0], reverse=True):
        ids = sent[1:].tolist()  # drop the leading EOS token
        tok = [env.id2word[wid] for wid in ids]

        try:
            hyp = env.prefix_to_infix(tok)
            hyp_sympy = env.infix_to_sympy(hyp)

            # Differentiate w.r.t. f's variable (fall back to the env's x if f is constant)
            var = next(iter(f.free_symbols), env.local_dict['x'])
            res = "OK" if simplify(hyp_sympy.diff(var) - f, seconds=1) == 0 else "NO"

            label = f"{score:.5f}  {res}"
            display(Math(rf"\text{{{label}}} \quad \Rightarrow \quad {sp.latex(hyp_sympy)}"))

        except Exception:
            label = f"{score:.5f}  INVALID PREFIX EXPRESSION"
            display(Math(rf"\text{{{label}}} \quad \Rightarrow \quad \text{{{tok}}}"))

In [None]:
display_model_predictions(f, F, hypotheses, env)

### 🧩 Wrapping Up the Prediction Process

To streamline the workflow, we now wrap the entire prediction pipeline (parsing the function $F$, computing its derivative $f$, preparing the input, encoding it, and generating outputs via beam search) into a single helper function.

This will make it easier to test different functions with minimal boilerplate, while still giving access to all intermediate results like $f$, $F$, and the list of candidate hypotheses.

In [None]:
def predict_antiderivatives(F_infix, env, encoder, decoder, beam_size=10):
    """
    Given an infix expression for F, compute its derivative f,
    generate prefix input for the model, run beam search and return all results.

    Returns:
        f: sympy expression (the derivative)
        F: sympy expression (the reference antiderivative)
        hypotheses: list of (score, decoded_tensor) from beam search
    """
    import sympy as sp
    import torch

    # Parse the input expression and differentiate w.r.t. the environment's x
    F = sp.S(F_infix, locals=env.local_dict)
    var = env.local_dict['x']
    f = F.diff(var)

    # Convert derivative to prefix and prepare model input
    f_prefix = env.sympy_to_prefix(f)
    x1_prefix = env.clean_prefix(['sub', 'derivative', 'f', 'x', 'x'] + f_prefix)
    x1 = torch.LongTensor(
        [env.eos_index] + [env.word2id[w] for w in x1_prefix] + [env.eos_index]
    ).view(-1, 1)
    len1 = torch.LongTensor([len(x1)])
    x1, len1 = x1.to(device), len1.to(device)  # match the model's device

    # Encode the input
    with torch.no_grad():
        encoded = encoder('fwd', x=x1, lengths=len1, causal=False).transpose(0, 1)

    # Beam search decoding
    with torch.no_grad():
        _, _, beam = decoder.generate_beam(
            encoded,
            len1,
            beam_size=beam_size,
            length_penalty=1.0,
            early_stopping=1,
            max_len=200
        )

    hypotheses = beam[0].hyp
    return f, F, hypotheses

from sympy.parsing.sympy_parser import parse_expr



In [None]:
F_infix = 'x**(-3)'  # You can replace this with any symbolic function

# Step 1: Differentiate F and run beam search on f = F'
f, F, hypotheses = predict_antiderivatives(F_infix, env, encoder, decoder)

# Step 2: Display predictions
display_model_predictions(f, F, hypotheses, env)

### 🔧 Task: Predict from a Given Derivative Expression (3 points)

In this task, you will implement a function that predicts possible antiderivatives $F(x)$ for a given derivative $f(x)$, using the Transformer-based sequence-to-sequence model loaded above.

The model operates on prefix expressions of mathematical functions. Your function must:
- Parse an infix string representing $f(x)$ into a SymPy expression **without simplifying** it.
- Convert the expression to a prefix format suitable for the model.
- Encode it and run beam search, returning the hypotheses in the same format as `predict_antiderivatives`.


In [None]:
def predict_from_derivative(f_infix, env, encoder, decoder, beam_size=10):
    """
    Given an infix expression for the derivative f(x), predict possible antiderivatives F(x).

    Parameters:
        f_infix (str): A string in infix notation (e.g., "x + x**2")
        env: A symbolic environment that provides tokenization and parsing utilities
        encoder: The encoder model
        decoder: The decoder model
        beam_size (int): Number of beams to explore during decoding

    Returns:
        f (sympy expression): Parsed symbolic expression for the derivative f(x)
        None: Placeholder for true F (not known in this setting)
        hypotheses (list): List of (score, decoded_tensor) from beam search
    """
    # <YOUR CODE>
    f = None
    hypotheses = list()
    return f, None, hypotheses

f, _, hypotheses = predict_from_derivative("x+x", env, encoder, decoder)
assert str(f) == "x + x", f"Expected 'x + x', got '{f}'"

display_model_predictions(f, _, hypotheses, env)

### Explain the results for `x + x`. Why are they different from the previous attempts? (1 point)

## 🔍 Problem 3: Inspecting the Attention Mechanism (5 Points total)



Recall the formula for self-attention in a Transformer layer:
$$
\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

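Here $d_k$ is the dimension of each key vector within a head (the embedding dimension divided by the number of heads). Each encoder layer applies this mechanism in every head independently, with separate learned projections producing $Q$, $K$, and $V$, and then concatenates the heads' outputs.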
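The formula can be sketched in NumPy for a toy sequence. This is a simplified illustration with random weights and toy sizes; the function names are our own, and it omits the output projection, masking, and dropout that the repository's attention modules apply:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(h, Wq, Wk, Wv, n_heads):
    """Scaled dot-product self-attention split across heads.

    h: [seq_len, emb_dim] hidden states; Wq, Wk, Wv: [emb_dim, emb_dim].
    Returns (output [seq_len, emb_dim], weights [n_heads, seq_len, seq_len]).
    """
    seq_len, emb_dim = h.shape
    d_k = emb_dim // n_heads

    def split(m):  # [seq_len, emb_dim] -> [n_heads, seq_len, d_k]
        return m.reshape(seq_len, n_heads, d_k).transpose(1, 0, 2)

    Q, K, V = split(h @ Wq), split(h @ Wk), split(h @ Wv)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)            # [n_heads, seq_len, seq_len]
    weights = softmax(scores, axis=-1)                           # each row sums to 1
    context = weights @ V                                        # [n_heads, seq_len, d_k]
    out = context.transpose(1, 0, 2).reshape(seq_len, emb_dim)   # concatenate heads
    return out, weights

rng = np.random.default_rng(0)
emb_dim, n_heads, seq_len = 16, 4, 6     # toy sizes (the real model uses 1024 and 8)
h = rng.normal(size=(seq_len, emb_dim))
Wq, Wk, Wv = (rng.normal(size=(emb_dim, emb_dim)) for _ in range(3))
out, weights = multi_head_self_attention(h, Wq, Wk, Wv, n_heads)
print(out.shape, weights.shape)
```

The `weights` array has the same `[heads, tgt_len, src_len]` layout as the per-layer matrices visualized below (without the batch dimension).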


### 🔍  Extracting Attention Matrices from the Encoder

In order to better understand and visualize what the encoder is focusing on, we define a utility function `get_attention_matrices`. This function uses **forward hooks** in PyTorch to extract the raw attention weights from each attention layer in the encoder during a forward pass.

A **hook** is a special mechanism that allows us to tap into the forward (or backward) pass of a model and access intermediate data such as activations or gradients. Here, we register a **forward hook** on each attention module of the encoder, so that when the encoder processes an input, we can capture and store its internal attention matrices.

This is particularly useful for:
- Visualizing attention patterns over the input sequence
- Understanding how different layers attend to various input tokens
- Interpreting model behavior and debugging training dynamics

The function returns a list of attention matrices, one for each encoder layer.

> **Note**: The forward pass runs inside `torch.no_grad()`, so no computation graph is built and memory usage stays low. This is separate from evaluation mode, which is set with `encoder.eval()` (as done before the visualizations below).
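The hook mechanism can be tried on a toy module first. `ToyAttention` below is a hypothetical stand-in that, like the attention modules read by `get_attention_matrices`, stores its weights in an `outputs` attribute; the hook reads that attribute after each forward call:

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    """Hypothetical stand-in for an attention layer that stores its
    attention weights in `self.outputs` during the forward pass."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.outputs = None

    def forward(self, x):                          # x: [seq_len, dim]
        scores = self.proj(x) @ x.T / x.shape[-1] ** 0.5
        self.outputs = scores.softmax(dim=-1)      # [seq_len, seq_len], rows sum to 1
        return self.outputs @ x

captured = []

def save_attention(module, inputs, output):
    # Runs right after module.forward(); read the weights the module stored.
    captured.append(module.outputs.detach())

layers = nn.ModuleList([ToyAttention(4) for _ in range(2)])
handles = [layer.register_forward_hook(save_attention) for layer in layers]

hidden = torch.randn(3, 4)
with torch.no_grad():
    for layer in layers:
        hidden = layer(hidden)

for handle in handles:   # remove hooks so later forward passes are not recorded
    handle.remove()

print(len(captured), tuple(captured[0].shape))
```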

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch

def get_attention_matrices(encoder, input_ids, input_lengths, causal=False):
    """
    Extract attention matrices from encoder for visualization.

    Returns:
        encoder_attentions: List of attention matrices from each encoder layer
    """
    encoder_attentions = []

    # Register hooks to capture attention matrices
    def encoder_attention_hook(module, input, output):
        # The attention layer stores its weights in module.outputs (enabled by TransformerModel.STORE_OUTPUTS = True)
        attn_weights = module.outputs
        encoder_attentions.append(attn_weights.detach())

    # Add hooks to all attention modules in encoder
    hooks = []
    for layer in encoder.attentions:
        hook = layer.register_forward_hook(encoder_attention_hook)
        hooks.append(hook)

    # Forward pass through encoder to capture attention
    with torch.no_grad():
        encoder_output = encoder('fwd', x=input_ids, lengths=input_lengths, causal=causal)

    # Remove hooks
    for hook in hooks:
        hook.remove()

    return encoder_attentions


In [None]:
def visualize_attention(attention_matrix, tokens_in, tokens_out=None, layer_name="", f_expr=None, F_expr=None, figsize=(16, 10)):
    """
    Visualize attention heads in a 4x2 grid with full-sized readable plots.

    Args:
        attention_matrix: Tensor of shape [heads, tgt_len, src_len]
        tokens_in: Tokens on x-axis (input sequence)
        tokens_out: Tokens on y-axis (output sequence)
        layer_name: Layer title
        f_expr: Optional sympy expression for derivative
        F_expr: Optional sympy expression for integral
        figsize: Size for each full grid (e.g. 16x10 for 8 heads)
    """
    if tokens_out is None:
        tokens_out = tokens_in

    n_heads = attention_matrix.shape[0]
    rows, cols = n_heads // 2, 2  # grid layout: 4x2 for 8 heads

    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    axes = axes.flatten()

    for h in range(n_heads):
        ax = axes[h]
        data = attention_matrix[h].cpu().numpy()

        sns.heatmap(
            data,
            ax=ax,
            cmap="Blues",
            annot=True,
            fmt=".2f",
            annot_kws={"size": 9},
            xticklabels=tokens_in,
            yticklabels=tokens_out,
            cbar=False,
            linewidths=0.5,
            linecolor='gray',
            vmin=0,
            vmax=1
        )
        ax.set_title(f"Head {h + 1}", fontsize=12)
        ax.set_xlabel("Source")
        ax.set_ylabel("Target")
        ax.tick_params(axis='x', labelrotation=45)
        ax.tick_params(axis='both', labelsize=10)

    # Hide extra axes if any
    for h in range(n_heads, len(axes)):
        axes[h].axis('off')

    # Title text
    title = f"{layer_name}"
    if f_expr is not None and F_expr is not None:
        title += f"\n$f = {sp.latex(f_expr)}$,\n$F = {sp.latex(F_expr)}$"

    plt.suptitle(title, y=1.03, fontsize=16)
    plt.tight_layout()
    plt.subplots_adjust(top=0.88)
    return fig


In [None]:
def visualize_all_encoder_layers(encoder_attentions, input_tokens, figsize=(15, 20)):
    """
    Visualize attention for each layer in the encoder
    """
    for i, attn in enumerate(encoder_attentions):
        # attn shape: [batch_size, num_heads, seq_len, seq_len]
        # For visualization, we take the first batch item
        attn_matrix = attn[0]  # [num_heads, seq_len, seq_len]

        fig = visualize_attention(
            attn_matrix,
            input_tokens,
            layer_name=f"Encoder Layer {i+1}",
            figsize=figsize
        )
        plt.show()

In [None]:
def visualize_model_attention(f_prefix, env, encoder, decoder):
    # Prepare input for visualization
    x1_prefix = env.clean_prefix(['sub', 'derivative', 'f', 'x', 'x'] + f_prefix)
    x1 = torch.LongTensor(
        [env.eos_index] +
        [env.word2id[w] for w in x1_prefix] +
        [env.eos_index]
    ).view(-1, 1)
    len1 = torch.LongTensor([len(x1)])

    x1, len1 = x1.to(device), len1.to(device)  # match the model's device

    # Get attention matrices
    encoder_attentions = get_attention_matrices(encoder, x1, len1)

    # Prepare token labels for visualization
    input_tokens = ['<EOS>'] + x1_prefix + ['<EOS>']

    # Visualize attention matrices
    visualize_all_encoder_layers(encoder_attentions, input_tokens)

    return encoder_attentions

In [None]:
from IPython.display import display, Math

def run_attention_visualizations(F_infix, env, encoder, decoder):
    """
    Parse a function F, compute its derivative f, tokenize it,
    and visualize encoder attention for the model input based on f.

    Returns:
        encoder_attentions: List of attention matrices
        f: SymPy expression (F')
        F: SymPy expression (original)
        f_prefix: list of tokens (prefix form of f)
    """
    print("🔍 Analyzing symbolic integration task")
    print(f"Raw input (F_infix): {F_infix}")

    # Parse function F and compute its derivative f
    F = sp.S(F_infix, locals=env.local_dict)
    var = list(F.free_symbols)[0]
    f = F.diff(var)

    print("\n📌 Parsed expressions:")
    display(Math(rf"F(x) = {sp.latex(F)}"))
    display(Math(rf"f(x) = \frac{{d}}{{dx}}F(x) = {sp.latex(f)}"))

    # Tokenize f to prefix form for the model
    f_prefix = env.sympy_to_prefix(f)
    print(f"\n🧮 Prefix input to the model (f = F'): {f_prefix}")

    # Visualize attention matrices
    encoder_attentions = visualize_model_attention(f_prefix, env, encoder, decoder)

    return encoder_attentions, f, F, f_prefix

In [None]:
encoder.eval()
decoder.eval()
attentions_example=run_attention_visualizations('sin(31*exp(-2*x))', env, encoder, decoder)

### 🔍 Task: Interpret Transformer Attention Heads (5 Points)

In this assignment, you will explore and interpret what some attention heads are doing in a pretrained transformer model that attempts to generate antiderivatives of symbolic functions.

---

### 🎯 Objective:

Choose **4 attention heads** from any layers and try to hypothesize what each of them might be focusing on.

For each head, you must:
- Specify the **layer number** and **head number**
- Provide a **plausible explanation** of what the head might be attending to
- Use your judgment to back the claim — even speculative insights are welcome!

⚠️ Keep in mind:  
Most heads are not cleanly interpretable — and that’s okay. The goal is to practice the **process** of inspection and reasoning.

---

### 🧪 Default setup:

By default, you can use the input function:  
$$
\sin(31 \cdot e^{-2x})
$$

This function is complex enough to exhibit interesting symbolic structure.
In the first three layers you should be able to find some insightful heads.

---

### 🛠️ Notes:

- You are welcome (and encouraged!) to try **your own function inputs** instead of the default.
- If you do, make sure to include **all required outputs and visuals** to support your interpretation (e.g., tokenized inputs, attention weights).
- Try to reflect on whether the heads are focusing on:
  - Subexpression structure? (Prefix notation has no parentheses, so nesting is encoded by token order.)
  - Operator precedence?
  - Specific tokens like `x`, `exp`, `sin`?
  - Repetitive patterns?
  - Numbers?
  - Signs?

---

### ✅ Submission format (for each head):

Please follow this format for each of your 4 analyses:

```
Layer: <layer_number>
Head: <head_number>

🧠 Interpretation:
<your explanation>

🔍 Evidence:
<describe what makes you think that – show attention maps or token examples if possible>
```

---

Happy decoding! 🧠🔬

## 🔁 Bonus Task: Cross-Attention in Transformer Decoders (5 points total)

In contrast to **self-attention**, which attends to tokens within the same sequence, **cross-attention** occurs in the decoder and allows it to attend to the encoder's output representations.

This mechanism is crucial for tasks like translation, summarization, or symbolic integration, where the decoder needs to align its generated output with specific parts of the input.
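In cross-attention the queries come from the decoder states while the keys and values come from the encoder output, so each weight matrix is rectangular (target length by source length). A toy single-head NumPy sketch with random data and no learned projections; the function name is our own:

```python
import numpy as np

def cross_attention(dec, enc):
    """Single-head cross-attention without learned projections (toy sketch).

    dec: [tgt_len, d] decoder states (queries)
    enc: [src_len, d] encoder outputs (keys and values)
    Returns (context [tgt_len, d], weights [tgt_len, src_len]).
    """
    d = dec.shape[-1]
    scores = dec @ enc.T / np.sqrt(d)                  # [tgt_len, src_len]
    scores -= scores.max(axis=-1, keepdims=True)       # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)     # softmax over source tokens
    return weights @ enc, weights

rng = np.random.default_rng(0)
enc = rng.normal(size=(7, 8))   # 7 input (source) tokens
dec = rng.normal(size=(3, 8))   # 3 generated (target) tokens so far
context, weights = cross_attention(dec, enc)
print(weights.shape)            # one row per target token, one column per source token
```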

### 🔍 What We'll Do

In this section, we will:
- Capture the **cross-attention matrices** from all decoder layers.
- Use forward hooks on the decoder's cross-attention modules (`encoder_attn`) to collect attention weights during autoregressive generation.
- For each decoder layer, we will obtain an attention tensor of shape `[batch, heads, target_len, source_len]`.
These matrices allow us to analyze how the decoder focuses on different parts of the input sequence while generating each token.

We wrap this functionality in the `get_cross_attention_matrices` function.


### 🧪 Implementation Task: Extracting Cross-Attention Weights (4 points)

To analyze how the decoder attends to the encoder's outputs, complete the function `get_cross_attention_matrices`.

This function will:
- Register **forward hooks** on each decoder cross-attention (`encoder_attn`) block to capture attention weights during generation.
- Run the encoder and then the decoder to generate a sequence.
- Collect the cross-attention tensors across decoder layers.

The function should return:
- `cross_attentions`: A list with one tensor per decoder layer of shape `[batch, heads, target_len, source_len]`.
- `generated`: The generated sequence (token IDs).
- `gen_len`: The length(s) of generated sequences.

The missing parts are clearly marked as `<YOUR CODE>` — fill them in using your knowledge of encoder-decoder workflows in PyTorch.


This task is worth **4 points**.

In [None]:
def get_cross_attention_matrices(encoder, decoder, input_ids, input_lengths, causal=False):
    """
    Extract cross-attention weights from decoder during generation.

    Args:
        encoder: Transformer encoder
        decoder: Transformer decoder
        input_ids: LongTensor of shape [seq_len, batch]
        input_lengths: LongTensor of shape [batch]
        causal: Whether to use causal attention (not needed here)

    Returns:
        cross_attentions: List of tensors, one per decoder layer,
                          each of shape [batch, heads, tgt_len, src_len]
        generated: Tensor of generated token ids
        gen_len: Lengths of generated sequences
    """
    cross_attn_chunks = []

    def cross_attention_hook(module, input, output):
        attn_weights = module.outputs  # shape: [batch, heads, 1, src_len] per timestep
        cross_attn_chunks.append(attn_weights.detach())

    # Register hooks on all cross-attention layers in decoder
    hooks = [layer.register_forward_hook(cross_attention_hook) for layer in decoder.encoder_attn]

    # Run encoder and decoder
    with torch.no_grad():
        # <YOUR CODE>
        pass  # placeholder so the cell parses; replace with your implementation

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # <YOUR CODE>

    return cross_attentions, generated, gen_len

In [None]:
from SymbolicMathematics.src.envs.sympy_utils import simplify  # time-limited simplify(expr, seconds)
from IPython.display import display, Math

def visualize_model_cross_attention(f_prefix, env, encoder, decoder):
    # Prepare input for visualization
    x1_prefix = env.clean_prefix(['sub', 'derivative', 'f', 'x', 'x'] + f_prefix)
    x1 = torch.LongTensor(
        [env.eos_index] +
        [env.word2id[w] for w in x1_prefix] +
        [env.eos_index]
    ).view(-1, 1)
    len1 = torch.LongTensor([len(x1)])

    x1, len1 = x1.to(device), len1.to(device)  # match the model's device

    # Get attention matrices and generated output
    encoder_cross_attentions, generated, gen_len = get_cross_attention_matrices(encoder, decoder, x1, len1)

    # Prepare input token labels
    input_tokens = ['<EOS>'] + x1_prefix + ['<EOS>']

    # Prepare output tokens for visualization (1 per decoder step, first sequence in the batch)
    output_tokens = [env.id2word[idx] for idx in generated[:, 0].tolist()]

    print("\n🔮 Predicted Function (first sequence in the batch):")
    pred_ids = generated[:, 0].tolist()
    pred_tokens = [env.id2word[wid] for wid in pred_ids if env.id2word[wid] not in ['<s>', '<EOS>', '<pad>']]

    try:
        pred_infix = env.prefix_to_infix(pred_tokens)
        pred_sympy = env.infix_to_sympy(pred_infix)

        # Try to validate against f
        f = sp.S(env.prefix_to_infix(f_prefix), locals=env.local_dict)
        var = next(iter(f.free_symbols), env.local_dict['x'])  # fall back to x if f is constant
        res = "OK" if simplify(pred_sympy.diff(var) - f, seconds=1) == 0 else "NO"

        label = f"{res}"
        display(Math(rf"\text{{{label}}} \quad \Rightarrow \quad {sp.latex(pred_sympy)}"))

    except Exception:
        display(Math(rf"\text{{INVALID PREFIX EXPRESSION}} \quad \Rightarrow \quad \text{{{pred_tokens}}}"))

    # === Visualize cross attention ===
    visualize_all_cross_layers(encoder_cross_attentions, input_tokens, output_tokens)

    return encoder_cross_attentions

In [None]:
def visualize_all_cross_layers(cross_attentions, input_tokens, output_tokens, figsize=(15, 20)):
    """
    Visualize cross-attention for each decoder layer
    """

    for i, attn in enumerate(cross_attentions):
        # attn shape: [batch_size, num_heads, tgt_len, src_len]
        # For visualization, we take the first batch item
        attn_matrix = attn[0]  # [num_heads, tgt_len, src_len]

        fig = visualize_attention(
            attn_matrix,
            input_tokens,
            output_tokens,
            layer_name=f"Decoder Layer {i+1} (cross-attention)",
            figsize=figsize
        )
        plt.show()

from IPython.display import display, Math

def run_cross_attention_visualizations(F_infix, env, encoder, decoder):
    """
    Visualize cross-attention maps for decoder attending to encoder representations,
    given a symbolic integration task.
    """
    print("🔍 Running cross-attention visualization")
    print(f"Raw input (F_infix): {F_infix}")

    # Parse and differentiate
    F = sp.S(F_infix, locals=env.local_dict)
    var = list(F.free_symbols)[0]
    f = F.diff(var)

    # Show parsed math
    print("\n📌 Parsed symbolic expressions:")
    display(Math(rf"F(x) = {sp.latex(F)}"))
    display(Math(rf"f(x) = \frac{{d}}{{dx}}F(x) = {sp.latex(f)}"))

    # Tokenize for input
    f_prefix = env.sympy_to_prefix(f)
    print(f"\n🧮 Tokenized prefix input to model (f = F'): {f_prefix}")

    # Visualize cross-attention
    cross_attentions = visualize_model_cross_attention(f_prefix, env, encoder, decoder)

    return cross_attentions

### ➕ Bonus Task: Cross-Attention Exploration (1 point)

Investigate the **cross-attention** mechanism between the encoder and decoder.

Your task:

1. **Pick one cross-attention head** and describe what kind of encoder tokens the decoder attends to at that position.
2. **Give a concrete example** (e.g., decoder token "sin" attends to what? Show attention weights or describe the pattern).

💡 Tip: Try it on the default input ($\sin(31 \cdot e^{-2x})$) or any other expression you analyzed earlier.

🎯 Focus on intuition — even rough insights are useful!

In [None]:
attentions_example=run_cross_attention_visualizations('sin(31*exp(-2*x))', env, encoder, decoder)