# Kernel Regression as In-Context Learning

When a transformer receives examples $(x_1, y_1), \ldots, (x_n, y_n)$ in its prompt and predicts $y$ for a new query $x_q$ — **without any weight update** — it is doing *in-context learning* (ICL).

This notebook shows that ICL is structurally **kernel regression**:

| In-Context Learning | Kernel Regression |
|---|---|
| Context examples $(x_i, y_i)$ | Training set |
| Query $x_q$ | Test point |
| Attention weights $\text{softmax}(x_q \cdot x_i)$ | Kernel similarities $k(x_q, x_i)$ |
| Output = weighted sum of $y_i$ | Prediction = weighted sum of $y_i$ |
| No weight update | No weight update |

We'll build this connection in three steps:

1. **Ridge → Dual form**: rewrite ridge regression so data appears only through dot products
2. **Dual → Kernel regression**: replace dot products with a kernel (the kernel trick)
3. **Kernel regression = ICL**: show that a single softmax attention head computes a Nadaraya-Watson kernel estimate

In [None]:
import torch
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64
print(f"Using device: {device}")

## Step 1: Ridge Regression — Primal and Dual

Ridge regression minimises $\|Xw - y\|^2 + \lambda\|w\|^2$. Two equivalent solutions:

| | Primal | Dual |
|---|---|---|
| **Solve for** | $w \in \mathbb{R}^d$ | $\alpha \in \mathbb{R}^n$ |
| **Formula** | $w = (X^\top X + \lambda I_d)^{-1} X^\top y$ | $\alpha = (XX^\top + \lambda I_n)^{-1} y$ |
| **Predict** | $\hat{y} = X_\text{new}\, w$ | $\hat{y} = X_\text{new} X^\top \alpha$ |
| **Inverts** | $(d \times d)$ matrix | $(n \times n)$ matrix |

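They are the same: $w = X^\top \alpha$, which follows from the push-through identity (a special case of the Woodbury identity). But in the dual form, $X$ appears only through dot products: $XX^\top$ when fitting and $X_\text{new} X^\top$ when predicting.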
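
Why the two formulas agree can be seen in a few lines. This is the standard push-through argument, written with the symbols from the table above; both matrices in parentheses are invertible because $\lambda > 0$.

$$
\begin{aligned}
(X^\top X + \lambda I_d)\, X^\top &= X^\top X X^\top + \lambda X^\top = X^\top (X X^\top + \lambda I_n) \\
\Longrightarrow\quad (X^\top X + \lambda I_d)^{-1} X^\top &= X^\top (X X^\top + \lambda I_n)^{-1} \\
\Longrightarrow\quad w = (X^\top X + \lambda I_d)^{-1} X^\top y &= X^\top (X X^\top + \lambda I_n)^{-1} y = X^\top \alpha
\end{aligned}
$$

Hence $\hat{y} = X_\text{new}\, w = X_\text{new} X^\top \alpha$, which matches the Predict row.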

In [None]:
class RidgePrimal:
    """w = (X^T X + λI)^{-1} X^T y"""
    def __init__(self, lam=1.0):
        self.lam = lam
    def fit(self, X, y):
        d = X.shape[1]
        A = X.T @ X + self.lam * torch.eye(d, device=X.device, dtype=X.dtype)
        self.w = torch.linalg.solve(A, X.T @ y)
        return self
    def predict(self, X):
        return X @ self.w

class RidgeDual:
    """α = (XX^T + λI)^{-1} y  — same answer, data appears only as dot products"""
    def __init__(self, lam=1.0):
        self.lam = lam
    def fit(self, X, y):
        self.X_train = X
        n = X.shape[0]
        K = X @ X.T
        self.alpha = torch.linalg.solve(K + self.lam * torch.eye(n, device=X.device, dtype=X.dtype), y)
        return self
    def predict(self, X):
        return (X @ self.X_train.T) @ self.alpha

In [None]:
# Verify: primal and dual give the same predictions
torch.manual_seed(0)
X = torch.randn(50, 3, device=device, dtype=dtype)
y = torch.randn(50, 1, device=device, dtype=dtype)
X_new = torch.randn(10, 3, device=device, dtype=dtype)

p1 = RidgePrimal(lam=1.0).fit(X, y).predict(X_new)
p2 = RidgeDual(lam=1.0).fit(X, y).predict(X_new)
print(f"Max |primal - dual|: {(p1 - p2).abs().max():.2e}  (should be ~0)")

## Step 2: Kernel Ridge Regression

In the dual form, data only appears through the Gram matrix $K_{ij} = x_i \cdot x_j$.

**The kernel trick:** replace the dot product with any kernel function $k(x_i, x_j)$:

$$\alpha = (K + \lambda I)^{-1} y, \qquad \hat{y}_q = \sum_i k(x_q, x_i)\, \alpha_i$$

The algorithm is identical to dual ridge — just a different Gram matrix.

In [None]:
def rbf_kernel(X1, X2, sigma=1.0):
    """k(x, x') = exp(-||x - x'||^2 / (2σ^2))"""
    sq1 = (X1 ** 2).sum(dim=1, keepdim=True)
    sq2 = (X2 ** 2).sum(dim=1, keepdim=True)
    # ||a - b||^2 = ||a||^2 - 2 a·b + ||b||^2; clamp guards against tiny negatives from round-off
    dist_sq = (sq1 - 2.0 * X1 @ X2.T + sq2.T).clamp(min=0.0)
    return torch.exp(-dist_sq / (2.0 * sigma ** 2))

class KernelRidge:
    """α = (K + λI)^{-1} y  with arbitrary kernel."""
    def __init__(self, kernel_fn=rbf_kernel, lam=1.0, **kw):
        self.kernel_fn, self.lam, self.kw = kernel_fn, lam, kw
    def fit(self, X, y):
        self.X_train = X
        n = X.shape[0]
        K = self.kernel_fn(X, X, **self.kw)
        self.alpha = torch.linalg.solve(K + self.lam * torch.eye(n, device=X.device, dtype=X.dtype), y)
        return self
    def predict(self, X):
        return self.kernel_fn(X, self.X_train, **self.kw) @ self.alpha
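
The claim that kernel ridge is "dual ridge with a different Gram matrix" can be checked directly: with the linear kernel $k(x, x') = x \cdot x'$ it should reproduce primal ridge. The sketch below is self-contained; `linear_kernel`, `krr_predict`, and `ridge_primal_predict` are illustrative helper names, not part of the classes above.

```python
import torch

def linear_kernel(X1, X2):
    """k(x, x') = x · x' : the kernel that ridge regression uses implicitly."""
    return X1 @ X2.T

def krr_predict(X_train, y_train, X_test, kernel_fn, lam=1.0):
    """Kernel ridge: alpha = (K + lam*I)^{-1} y, then y_hat = K_test @ alpha."""
    n = X_train.shape[0]
    K = kernel_fn(X_train, X_train)
    eye = torch.eye(n, dtype=X_train.dtype)
    alpha = torch.linalg.solve(K + lam * eye, y_train)
    return kernel_fn(X_test, X_train) @ alpha

def ridge_primal_predict(X_train, y_train, X_test, lam=1.0):
    """Primal ridge: w = (X^T X + lam*I)^{-1} X^T y, then y_hat = X_test @ w."""
    d = X_train.shape[1]
    A = X_train.T @ X_train + lam * torch.eye(d, dtype=X_train.dtype)
    w = torch.linalg.solve(A, X_train.T @ y_train)
    return X_test @ w

torch.manual_seed(0)
X_tr = torch.randn(30, 4, dtype=torch.float64)
y_tr = torch.randn(30, 1, dtype=torch.float64)
X_te = torch.randn(5, 4, dtype=torch.float64)

krr_linear_pred = krr_predict(X_tr, y_tr, X_te, linear_kernel, lam=0.5)
ridge_pred = ridge_primal_predict(X_tr, y_tr, X_te, lam=0.5)
max_gap = (krr_linear_pred - ridge_pred).abs().max().item()
print(f"Max |KRR(linear) - ridge|: {max_gap:.2e}")
```

Swapping `linear_kernel` for any other positive semi-definite kernel changes the function class but leaves the solve untouched.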

## Step 3: Kernel Regression = In-Context Learning

Now the key connection. Consider a **single attention head** that receives context examples and a query:

$$\text{Attention}(x_q, \{x_i, y_i\}) = \sum_i \underbrace{\frac{\exp(x_q \cdot x_i)}{\sum_j \exp(x_q \cdot x_j)}}_{\text{softmax attention weight}} \cdot y_i$$

This is the **Nadaraya-Watson kernel estimator** with the exponential dot-product kernel $k(x_q, x_i) = \exp(x_q \cdot x_i)$; the softmax supplies the normalization:

$$\hat{y}_q = \frac{\sum_i k(x_q, x_i)\, y_i}{\sum_j k(x_q, x_j)}$$

The structure is identical:
- **Keys** $= x_i$ (context inputs)
- **Values** $= y_i$ (context labels)
- **Query** $= x_q$ (the new input)
- **Attention** = normalized kernel similarity
- **Output** = similarity-weighted average of values

**No weights are updated.** The prediction comes entirely from the context — just like kernel regression uses only the training set at test time.

Below we implement both and compare them directly.

In [None]:
def nadaraya_watson(X_ctx, y_ctx, X_query, kernel_fn, **kw):
    """Nadaraya-Watson: ŷ = Σ k(xq, xi) yi / Σ k(xq, xj)
    
    This is what a single attention head computes.
    """
    K = kernel_fn(X_query, X_ctx, **kw)   # (n_query, n_ctx) — like QK^T
    weights = K / K.sum(dim=1, keepdim=True)  # normalize — like softmax
    return weights @ y_ctx                    # weighted sum of values


def attention_icl(X_ctx, y_ctx, X_query):
    """Single-head attention in-context learning.
    
    Keys = X_ctx, Values = y_ctx, Query = X_query.
    A stripped-down attention head: one head, no learned projections, no 1/sqrt(d) scaling.
    """
    logits = X_query @ X_ctx.T                # (n_query, n_ctx) — QK^T
    weights = torch.softmax(logits, dim=1)    # softmax attention
    return weights @ y_ctx                    # weighted sum of values
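
The correspondence between the two functions above can be verified numerically. The following sketch (standalone, with made-up random data) checks two things: softmax attention over dot products equals Nadaraya-Watson with $k = \exp(x_q \cdot x_i)$, and RBF Nadaraya-Watson equals softmax attention whose logits are $-\|x_q - x_i\|^2 / (2\sigma^2)$.

```python
import torch

torch.manual_seed(0)
keys = torch.randn(8, 3, dtype=torch.float64)      # context inputs x_i
values = torch.randn(8, 1, dtype=torch.float64)    # context labels y_i
queries = torch.randn(4, 3, dtype=torch.float64)   # query inputs x_q

# (a) Softmax attention over raw dot products
attn_out = torch.softmax(queries @ keys.T, dim=1) @ values

# (b) Nadaraya-Watson with k(x_q, x_i) = exp(x_q · x_i)
K_exp = torch.exp(queries @ keys.T)
nw_out = (K_exp / K_exp.sum(dim=1, keepdim=True)) @ values

# (c) RBF Nadaraya-Watson vs softmax over logits -||x_q - x_i||^2 / (2 sigma^2)
sigma = 0.7
dist_sq = torch.cdist(queries, keys) ** 2
K_rbf = torch.exp(-dist_sq / (2 * sigma ** 2))
nw_rbf_out = (K_rbf / K_rbf.sum(dim=1, keepdim=True)) @ values
attn_rbf_out = torch.softmax(-dist_sq / (2 * sigma ** 2), dim=1) @ values

gap_exp = (attn_out - nw_out).abs().max().item()
gap_rbf = (attn_rbf_out - nw_rbf_out).abs().max().item()
print(f"softmax vs exp-kernel NW: {gap_exp:.2e}")
print(f"softmax vs RBF NW:        {gap_rbf:.2e}")
```

In both cases the only difference between "attention" and "kernel smoothing" is which similarity score is exponentiated.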

## Demo: Noisy Sine Wave

The "context" is $(x_i, y_i)$ pairs from a noisy $\sin(x)$. We predict on new query points using:
1. **Ridge** (primal): fits a line through the origin (there is no intercept term), so it cannot capture the curve
2. **KRR** (RBF kernel): kernel regression with ridge regularization
3. **Nadaraya-Watson** (RBF): kernel regression with normalization (= attention)
4. **Attention ICL**: a literal single attention head over the context

In [None]:
torch.manual_seed(42)

# Context examples (= prompt examples in ICL)
n_ctx = 200
X_ctx = torch.linspace(-3, 3, n_ctx, device=device, dtype=dtype).unsqueeze(1)
y_ctx = torch.sin(X_ctx) + 0.2 * torch.randn(n_ctx, 1, device=device, dtype=dtype)

# Query points
n_query = 100
X_query = torch.linspace(-3, 3, n_query, device=device, dtype=dtype).unsqueeze(1)
y_true  = torch.sin(X_query)

# 1. Ridge (primal)
pred_ridge = RidgePrimal(lam=1.0).fit(X_ctx, y_ctx).predict(X_query)

# 2. Kernel Ridge Regression
pred_krr = KernelRidge(kernel_fn=rbf_kernel, lam=0.01, sigma=0.5).fit(X_ctx, y_ctx).predict(X_query)

# 3. Nadaraya-Watson (= attention with RBF kernel)
pred_nw = nadaraya_watson(X_ctx, y_ctx, X_query, rbf_kernel, sigma=0.5)

# 4. Attention ICL (softmax over raw dot products)
pred_attn = attention_icl(X_ctx, y_ctx, X_query)

mse = lambda a, b: ((a - b) ** 2).mean().item()
print("MSE on query points:")
print(f"  Ridge (linear):          {mse(y_true, pred_ridge):.6f}")
print(f"  KRR (RBF):               {mse(y_true, pred_krr):.6f}")
print(f"  Nadaraya-Watson (RBF):   {mse(y_true, pred_nw):.6f}")
print(f"  Attention ICL (softmax): {mse(y_true, pred_attn):.6f}")

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(18, 4), sharex=True, sharey=True)

xc = X_ctx.cpu().numpy().ravel()
yc = y_ctx.cpu().numpy().ravel()
xq = X_query.cpu().numpy().ravel()
yt = y_true.cpu().numpy().ravel()

titles = ["Ridge (linear)", "KRR (RBF kernel)", "Nadaraya-Watson (RBF)", "Attention ICL (softmax)"]
preds  = [pred_ridge, pred_krr, pred_nw, pred_attn]

for ax, title, pred in zip(axes, titles, preds):
    ax.scatter(xc, yc, s=5, alpha=0.3, color="gray", label="context")
    ax.plot(xq, yt, "k--", lw=1, label="true sin(x)")
    ax.plot(xq, pred.cpu().numpy().ravel(), "r-", lw=2, label="predicted")
    ax.set_title(title, fontsize=11)
    ax.legend(fontsize=7)

fig.suptitle("Kernel regression as in-context learning", fontsize=14)
plt.tight_layout()
plt.show()

## Visualizing the Attention / Kernel Weights

For a single query point, we can visualize *which context examples contribute* to the prediction. This is the attention pattern — or equivalently, the kernel similarity profile.

In [None]:
# Pick a query point
q_idx = 65  # somewhere on the right slope
xq_single = X_query[q_idx:q_idx+1]

# Kernel weights (Nadaraya-Watson)
K_nw = rbf_kernel(xq_single, X_ctx, sigma=0.5)
w_nw = (K_nw / K_nw.sum()).cpu().numpy().ravel()

# Attention weights
logits = xq_single @ X_ctx.T
w_attn = torch.softmax(logits, dim=1).cpu().numpy().ravel()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 3.5))

for ax, w, title in [(ax1, w_nw, "RBF kernel weights"), (ax2, w_attn, "Softmax attention weights")]:
    ax.bar(xc, w, width=0.04, color="steelblue", alpha=0.7)
    ax.axvline(xq_single.item(), color="red", ls="--", lw=2, label=f"query = {xq_single.item():.2f}")
    ax.set_title(title, fontsize=11)
    ax.set_xlabel("context x")
    ax.set_ylabel("weight")
    ax.legend(fontsize=9)

fig.suptitle("Which context examples matter for the prediction?", fontsize=13)
plt.tight_layout()
plt.show()

## Takeaway So Far

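Raw attention ICL (softmax over raw dot products) is a **bad kernel** for this problem. In 1-D the logit is $x_q x_i$, which grows with $x_i$ whenever $x_q > 0$ (and with $-x_i$ when $x_q < 0$), so the weight piles up on the context points at the edge of the range instead of on the points near the query.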
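
To make the failure concrete, the sketch below compares where the two weight profiles peak for a single positive query. It is standalone; the grid mirrors the demo's context inputs and the query value `1.0` is an arbitrary choice for illustration.

```python
import torch

x_ctx = torch.linspace(-3, 3, 200, dtype=torch.float64).unsqueeze(1)  # context inputs
x_q = torch.tensor([[1.0]], dtype=torch.float64)                       # one query

# Raw dot-product attention: logits = x_q * x_i, monotone in x_i for x_q > 0
w_dot = torch.softmax(x_q @ x_ctx.T, dim=1).ravel()

# RBF Nadaraya-Watson weights for comparison (sigma = 0.5, as in the demo)
w_rbf = torch.exp(-(x_ctx.ravel() - x_q.item()) ** 2 / (2 * 0.5 ** 2))
w_rbf = w_rbf / w_rbf.sum()

peak_dot = x_ctx[w_dot.argmax()].item()
peak_rbf = x_ctx[w_rbf.argmax()].item()
print(f"dot-product weights peak at x = {peak_dot:.2f}")
print(f"RBF weights peak at x = {peak_rbf:.2f}")
```

The dot-product profile peaks at the right edge of the context range, while the RBF profile peaks at the grid point closest to the query.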

A real transformer doesn't use raw dot products. It learns projection matrices $W_Q, W_K, W_V$ during **pretraining** on many tasks. These projections define a **learned kernel**:

$$k_\theta(x_q, x_i) = \text{softmax}\!\left(\frac{(x_q W_Q)(x_i W_K)^\top}{\sqrt{d_k}}\right)$$

This is the key: **pretraining = learning a good kernel**. Then at inference time, the learned kernel does ICL on new tasks it has never seen.
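
The $\sqrt{d_k}$ in the denominator keeps the logits at a scale that does not grow with the projection width. A minimal sketch, assuming (for illustration only) that projected queries and keys have independent, unit-variance entries; real projections need not satisfy this exactly:

```python
import torch

torch.manual_seed(0)
d_k = 64
n_pairs = 20000
q = torch.randn(n_pairs, d_k, dtype=torch.float64)  # stand-ins for projected queries
k = torch.randn(n_pairs, d_k, dtype=torch.float64)  # stand-ins for projected keys

raw_logits = (q * k).sum(dim=1)          # one dot product per (q, k) pair
scaled_logits = raw_logits / d_k ** 0.5

var_raw = raw_logits.var().item()
var_scaled = scaled_logits.var().item()
print(f"variance of raw logits:    {var_raw:.1f}  (close to d_k = {d_k})")
print(f"variance of scaled logits: {var_scaled:.2f}  (close to 1)")
```

Without the scaling, wider projections would push the softmax toward a one-hot pattern, which in kernel terms means an ever narrower bandwidth.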

## Step 4: Pre-trained Kernel for ICL

We'll train a single attention head on a **distribution of tasks** (random sinusoids with varying amplitude, frequency, phase). Each training step:

1. Sample a random function $f$
2. Generate context $(x_i, f(x_i) + \text{noise})$ and query $(x_q, f(x_q))$
3. Predict $\hat{y}_q$ via attention with learned $W_Q, W_K, W_V$, applied to small MLP embeddings of $x$ and $y$
4. Backprop MSE loss

After pretraining, we freeze the weights and test on a **new task never seen during training**. The learned kernel should do much better ICL than raw dot products.

In [None]:
class PretrainedAttentionICL(torch.nn.Module):
    """Single attention head with learnable projections.
    
    Architecture mirrors kernel regression:
      - Keys from x only    → the kernel measures similarity in input space
      - Values from y only  → the output is a weighted sum of context labels
      - Query from x only   → the query point
    
    This is a learned Nadaraya-Watson estimator.
    """
    def __init__(self, input_dim=1, d_model=64):
        super().__init__()
        # Separate embeddings for x (used in Q, K) and y (used in V)
        self.embed_x = torch.nn.Sequential(
            torch.nn.Linear(input_dim, d_model),
            torch.nn.ReLU(),
            torch.nn.Linear(d_model, d_model),
        )
        self.embed_y = torch.nn.Sequential(
            torch.nn.Linear(1, d_model),
            torch.nn.ReLU(),
            torch.nn.Linear(d_model, d_model),
        )
        
        # Q, K operate on x-embeddings → define the learned kernel
        self.W_Q = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_K = torch.nn.Linear(d_model, d_model, bias=False)
        # V operates on y-embeddings → read out labels
        self.W_V = torch.nn.Linear(d_model, d_model, bias=False)
        
        self.out = torch.nn.Linear(d_model, 1)
        self.scale = d_model ** 0.5
    
    def forward(self, X_ctx, y_ctx, X_query):
        """
        X_ctx:   (batch, n_ctx, input_dim)
        y_ctx:   (batch, n_ctx, 1)
        X_query: (batch, n_query, input_dim)
        Returns: (batch, n_query, 1)
        """
        # Keys from context x, Values from context y, Queries from query x
        K = self.W_K(self.embed_x(X_ctx))       # (batch, n_ctx, d_model)
        V = self.W_V(self.embed_y(y_ctx))        # (batch, n_ctx, d_model)
        Q = self.W_Q(self.embed_x(X_query))      # (batch, n_query, d_model)
        
        # Attention = learned kernel
        weights = torch.softmax(Q @ K.transpose(-2, -1) / self.scale, dim=-1)
        out = weights @ V                        # (batch, n_query, d_model)
        
        return self.out(out)                     # (batch, n_query, 1)

### Task distribution for pretraining

Each task is a random sinusoid: $f(x) = a \sin(\omega x + \phi)$ with random amplitude $a$, frequency $\omega$, and phase $\phi$. The model never sees the same function twice — it must learn to **read the context** to figure out which function it's dealing with.

In [None]:
def sample_task_batch(batch_size, n_ctx, n_query, noise_std=0.2):
    """Sample a batch of random sinusoid tasks.
    
    Each task:  f(x) = a * sin(ω*x + φ)  with random a, ω, φ.
    Returns context (X_ctx, y_ctx) and query (X_query, y_query).
    """
    # Random function parameters per task
    a     = 0.5 + 1.5 * torch.rand(batch_size, 1, 1, device=device)    # amplitude [0.5, 2]
    omega = 0.5 + 2.0 * torch.rand(batch_size, 1, 1, device=device)    # frequency [0.5, 2.5]
    phi   = 2 * 3.14159 * torch.rand(batch_size, 1, 1, device=device)  # phase [0, 2π]
    
    # Sample x locations uniformly
    X_all = 6.0 * torch.rand(batch_size, n_ctx + n_query, 1, device=device) - 3.0  # [-3, 3]
    y_all = a * torch.sin(omega * X_all + phi)
    
    # Split into context and query
    X_ctx   = X_all[:, :n_ctx]
    y_ctx   = y_all[:, :n_ctx] + noise_std * torch.randn_like(y_all[:, :n_ctx])
    X_query = X_all[:, n_ctx:]
    y_query = y_all[:, n_ctx:]  # clean targets for query
    
    return X_ctx, y_ctx, X_query, y_query

### Pretraining loop

Train the attention head on many random tasks. This is analogous to pretraining a transformer on diverse data — the model learns a kernel that works across tasks.

In [None]:
torch.manual_seed(42)

model = PretrainedAttentionICL(input_dim=1, d_model=64).to(device).float()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

n_steps = 5000
batch_size = 64
n_ctx_train = 40
n_query_train = 10

losses = []
for step in range(n_steps):
    X_c, y_c, X_q, y_q = sample_task_batch(batch_size, n_ctx_train, n_query_train)
    
    pred = model(X_c.float(), y_c.float(), X_q.float())
    loss = ((pred - y_q.float()) ** 2).mean()
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    
    if step == 0 or (step + 1) % 1000 == 0:
        print(f"  Step {step+1:5d}  loss: {loss.item():.4f}")

plt.figure(figsize=(8, 3))
plt.plot(losses, alpha=0.3, color="steelblue")
window = 100
smoothed = [sum(losses[max(0,i-window):i+1])/len(losses[max(0,i-window):i+1]) for i in range(len(losses))]
plt.plot(smoothed, color="darkblue", lw=2)
plt.xlabel("Training step")
plt.ylabel("MSE loss")
plt.title("Pretraining: learning a kernel across many tasks")
plt.tight_layout()
plt.show()

### Evaluation on a new task

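Now we freeze the model and test on **sin(x)**, the same task from earlier. This exact function was almost surely never sampled during pretraining, but it lies inside the task distribution ($a = 1$, $\omega = 1$, $\phi = 0$), so the test measures generalisation within the family of sinusoids rather than to a new function class. The context also differs from pretraining: 200 points on a regular grid here versus 40 uniformly sampled points per training task.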

In [None]:
# Same test data as before
torch.manual_seed(42)
n_ctx = 200
X_ctx = torch.linspace(-3, 3, n_ctx, device=device, dtype=dtype).unsqueeze(1)
y_ctx = torch.sin(X_ctx) + 0.2 * torch.randn(n_ctx, 1, device=device, dtype=dtype)
n_query = 100
X_query = torch.linspace(-3, 3, n_query, device=device, dtype=dtype).unsqueeze(1)
y_true  = torch.sin(X_query)

# Run pre-trained model (no weight updates — pure ICL)
model.eval()
with torch.no_grad():
    pred_pretrained = model(
        X_ctx.unsqueeze(0).float(),
        y_ctx.unsqueeze(0).float(),
        X_query.unsqueeze(0).float()
    ).squeeze(0).double()

# Compare all methods
pred_ridge = RidgePrimal(lam=1.0).fit(X_ctx, y_ctx).predict(X_query)
pred_krr   = KernelRidge(kernel_fn=rbf_kernel, lam=0.01, sigma=0.5).fit(X_ctx, y_ctx).predict(X_query)
pred_nw    = nadaraya_watson(X_ctx, y_ctx, X_query, rbf_kernel, sigma=0.5)
pred_attn  = attention_icl(X_ctx, y_ctx, X_query)

mse = lambda a, b: ((a - b) ** 2).mean().item()
print("MSE on query points:")
print(f"  Ridge (linear):               {mse(y_true, pred_ridge):.6f}")
print(f"  Attention ICL (raw, no train): {mse(y_true, pred_attn):.6f}")
print(f"  Nadaraya-Watson (RBF):         {mse(y_true, pred_nw):.6f}")
print(f"  KRR (RBF):                     {mse(y_true, pred_krr):.6f}")
print(f"  Attention ICL (pre-trained):   {mse(y_true, pred_pretrained):.6f}  ← learned kernel")

In [None]:
fig, axes = plt.subplots(1, 5, figsize=(22, 4), sharex=True, sharey=True)

xc = X_ctx.cpu().numpy().ravel()
yc = y_ctx.cpu().numpy().ravel()
xq = X_query.cpu().numpy().ravel()
yt = y_true.cpu().numpy().ravel()

titles = ["Ridge (linear)", "Attention ICL\n(raw)", "Nadaraya-Watson\n(RBF)", "KRR (RBF)", "Attention ICL\n(pre-trained)"]
preds  = [pred_ridge, pred_attn, pred_nw, pred_krr, pred_pretrained]

for ax, title, pred in zip(axes, titles, preds):
    ax.scatter(xc, yc, s=5, alpha=0.3, color="gray", label="context")
    ax.plot(xq, yt, "k--", lw=1, label="true sin(x)")
    ax.plot(xq, pred.cpu().numpy().ravel(), "r-", lw=2, label="predicted")
    ax.set_title(title, fontsize=10)
    ax.legend(fontsize=7)

fig.suptitle("Pre-trained kernel learns to do in-context learning", fontsize=14)
plt.tight_layout()
plt.show()

## Full Summary

### 1-D results (sin(x))

The Cosine, Cayley, and GA rows come from Step 5 below; the high-D table comes from Step 6.

| Method | Kernel | MSE |
|---|---|---|
| Ridge | Linear | 0.1713 |
| Attention ICL (raw) | Softmax dot product | 0.0806 |
| Nadaraya-Watson | RBF | 0.0147 |
| GA ICL (penalty) | $\beta_1\cos\theta - \beta_2\sin\theta$ | 0.0082 |
| Cayley ICL | $-\theta^2/(2T^2)$ | 0.0074 |
| KRR (RBF) | RBF | 0.0042 |
| Pre-trained attention | Learned softmax | 0.0039 |
| Cosine ICL | $\cos(\theta)/T$ | 0.0038 |

### High-D results ($x \in \mathbb{R}^{10}$ with correlated feature groups)

Task: 3 of 10 features share a hidden latent $z$; target $y = a\sin(bz + \phi)$. The model must discover which features belong to the group from context alone. Evaluated on 200 held-out tasks.

| Row Kernel | Row-only | Col+Row | Improvement |
|---|---|---|---|
| Softmax | 0.0997 | 0.0780 | +21.8% |
| Cosine | 0.0864 | 0.0718 | +16.9% |
| Cayley | 0.0861 | 0.0718 | +16.6% |
| GA | 0.0862 | 0.0720 | +16.4% |
| *KRR (RBF)* | *0.1165* | — | — |
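
The Improvement column appears to be the relative reduction in MSE from adding column attention. Under that reading (an inference from the numbers, not something stated alongside the table), the softmax row works out as:

$$
\text{Improvement} = \frac{\text{MSE}_\text{row-only} - \text{MSE}_\text{col+row}}{\text{MSE}_\text{row-only}} = \frac{0.0997 - 0.0780}{0.0997} \approx 21.8\%
$$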

### Key insights

1. **Ridge → Dual → Kernel**: the algebraic path from familiar to powerful
2. **Kernel regression = ICL**: softmax attention over the context is a Nadaraya-Watson kernel estimate
3. **The kernel matters**: Cosine, Cayley ($\theta$), and GA kernels each bring different geometric structure. In 1-D, Cosine ICL (0.0038) matches the learned kernel (0.0039) and beats hand-tuned KRR (0.0042)
4. **Pretraining = learning the kernel**: learned $W_Q, W_K, W_V$ adapt to the task distribution
5. **Geometric row kernels outperform softmax**: in higher dimensions, all three geometric kernels (Cosine 0.086, Cayley 0.086, GA 0.086) beat softmax (0.100) for row-only ICL
6. **Column attention helps consistently**: GA column attention lowers the MSE of every row kernel by 16–22%, discovering correlated feature groups from context data
7. **Column attention + GA kernels = best overall**: Col+Cosine and Col+Cayley (0.072) beat all other methods including KRR (0.117). The combination of geometric column attention (feature groups) and geometric row attention (sample similarity) gives the strongest ICL

## Step 5: Geometric Algebra Kernels — cos($\theta$) and $\theta$ (Cayley)

Standard attention uses the raw dot product as its kernel. But we can design **geometrically motivated kernels** that measure similarity differently.

From `ga_transformer`, we bring three attention mechanisms that operate on **normalized** vectors and use the **rotation angle** $\theta$ between them:

| Kernel | Score | Interpretation |
|---|---|---|
| **Cosine** | $\cos(\theta) / T$ | Standard cosine similarity — alignment |
| **Cayley ($\theta$)** | $-\theta^2 / (2T^2)$ | Squared rotation angle — like an RBF in angle space |
| **GA (inner + wedge)** | $\beta_1 \cos\theta + f(\sin\theta)$ | Alignment + orthogonality structure |

The Cayley kernel is particularly interesting: because softmax exponentiates the score, the resulting weight is proportional to $\exp(-\theta^2/(2T^2))$, a **Gaussian kernel in angle space** with bandwidth $T$. Nearby directions get high weight and distant directions get low weight, just like RBF but on the hypersphere.
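
The angle-space Gaussian can be sketched in a few lines. This is an illustration of the score formula in the table only: `angle_gaussian_weights` is a hypothetical helper, and the actual `CayleyAttention` module in `ga_icl` may differ in details such as clamping or masking.

```python
import torch

def angle_gaussian_weights(z_query, z_ctx, temperature=0.1):
    """Softmax over -theta^2 / (2 T^2), theta = angle between unit vectors.

    Hypothetical helper for illustration; not the ga_icl CayleyAttention code.
    """
    cos = (z_query @ z_ctx.T).clamp(-1.0, 1.0)   # guard acos against round-off
    theta = torch.acos(cos)
    return torch.softmax(-theta ** 2 / (2 * temperature ** 2), dim=-1)

# Context directions at known angles from the query direction (1, 0)
angles = torch.tensor([0.0, 0.1, 0.2, 0.4, 0.8], dtype=torch.float64)
z_ctx_demo = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
z_query_demo = torch.tensor([[1.0, 0.0]], dtype=torch.float64)

w_angle = angle_gaussian_weights(z_query_demo, z_ctx_demo, temperature=0.2).ravel()

# Reference: an explicitly normalised Gaussian in the angle
w_ref = torch.exp(-angles ** 2 / (2 * 0.2 ** 2))
w_ref = w_ref / w_ref.sum()
print(w_angle)
```

The weights fall off monotonically with the angle, and the temperature plays the role that $\sigma$ plays in the RBF kernel.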

In [None]:
import sys
sys.path.insert(0, "/home/asudjianto/jupyterlab/ga_transformer")
from ga_icl.attention import CayleyAttention, GAAttention, CosineAttention

# --- GA-based ICL predictors (Nadaraya-Watson with GA kernels) ---
# These take raw X, normalize, compute attention weights, then ŷ = weights @ y

def cosine_icl(X_ctx, y_ctx, X_query, temperature=0.1):
    """ICL using cosine similarity kernel: score = cos(θ) / T"""
    attn = CosineAttention(temperature=temperature)
    z_ctx   = X_ctx / X_ctx.norm(dim=1, keepdim=True).clamp(min=1e-8)
    z_query = X_query / X_query.norm(dim=1, keepdim=True).clamp(min=1e-8)
    weights = attn(z_query, z_ctx)
    return weights @ y_ctx

def cayley_icl(X_ctx, y_ctx, X_query, temperature=0.1):
    """ICL using Cayley (θ) kernel: score = -θ² / (2T²)
    
    θ = arccos(cos(θ)) = rotation angle between normalized vectors.
    This is a Gaussian kernel in angle space.
    """
    attn = CayleyAttention(temperature=temperature)
    z_ctx   = X_ctx / X_ctx.norm(dim=1, keepdim=True).clamp(min=1e-8)
    z_query = X_query / X_query.norm(dim=1, keepdim=True).clamp(min=1e-8)
    weights = attn(z_query, z_ctx)
    return weights @ y_ctx

def ga_icl(X_ctx, y_ctx, X_query, beta1=4.0, beta2=1.0, temperature=1.0,
           wedge_mode="penalty"):
    """ICL using GA kernel: score = β₁·cos(θ) + f(sin(θ))"""
    attn = GAAttention(beta1=beta1, beta2=beta2, temperature=temperature,
                       wedge_mode=wedge_mode)
    z_ctx   = X_ctx / X_ctx.norm(dim=1, keepdim=True).clamp(min=1e-8)
    z_query = X_query / X_query.norm(dim=1, keepdim=True).clamp(min=1e-8)
    weights = attn(z_query, z_ctx)
    return weights @ y_ctx

print("GA attention modules loaded.")

### Comparing GA kernels on the sin(x) task

Same context, same queries — we just swap the kernel that computes the attention weights.

**Important:** These kernels operate on **normalized** vectors. For 1-D input, $x / \|x\|$ collapses to $\pm 1$ (sign only). So we need to embed our scalar $x$ into a higher-dimensional space first — we'll use a simple random Fourier feature embedding to give the kernels something meaningful to work with.

In [None]:
# Embed scalar x into higher-D space so normalization is meaningful.
# Random Fourier Features: φ(x) = [cos(w₁x), sin(w₁x), cos(w₂x), sin(w₂x), ...]
# This is a standard trick for making RBF-like features explicit.

torch.manual_seed(42)
n_rff = 32  # number of random frequencies
W_rff = torch.randn(1, n_rff, device=device, dtype=dtype) * 2.0  # random frequencies

def rff_embed(X):
    """Random Fourier Features: (n, 1) → (n, 2*n_rff)"""
    proj = X @ W_rff  # (n, n_rff)
    return torch.cat([torch.cos(proj), torch.sin(proj)], dim=1)

# Re-use same test data
torch.manual_seed(42)
n_ctx = 200
X_ctx = torch.linspace(-3, 3, n_ctx, device=device, dtype=dtype).unsqueeze(1)
y_ctx = torch.sin(X_ctx) + 0.2 * torch.randn(n_ctx, 1, device=device, dtype=dtype)
n_query = 100
X_query = torch.linspace(-3, 3, n_query, device=device, dtype=dtype).unsqueeze(1)
y_true  = torch.sin(X_query)

# Embed into RFF space
Z_ctx   = rff_embed(X_ctx)
Z_query = rff_embed(X_query)

# --- Run all GA kernels ---
pred_cosine = cosine_icl(Z_ctx, y_ctx, Z_query, temperature=0.1)
pred_cayley = cayley_icl(Z_ctx, y_ctx, Z_query, temperature=0.1)
pred_ga_pen = ga_icl(Z_ctx, y_ctx, Z_query, beta1=4.0, beta2=1.0,
                     temperature=1.0, wedge_mode="penalty")

# Baselines for comparison
pred_krr  = KernelRidge(kernel_fn=rbf_kernel, lam=0.01, sigma=0.5).fit(X_ctx, y_ctx).predict(X_query)
pred_nw   = nadaraya_watson(X_ctx, y_ctx, X_query, rbf_kernel, sigma=0.5)

mse = lambda a, b: ((a - b) ** 2).mean().item()
print("MSE on query points:")
print(f"  KRR (RBF):              {mse(y_true, pred_krr):.6f}")
print(f"  Nadaraya-Watson (RBF):  {mse(y_true, pred_nw):.6f}")
print(f"  Cosine ICL (cos θ/T):   {mse(y_true, pred_cosine):.6f}")
print(f"  Cayley ICL (-θ²/2T²):   {mse(y_true, pred_cayley):.6f}")
print(f"  GA ICL (penalty):       {mse(y_true, pred_ga_pen):.6f}")
print(f"  Pre-trained attention:  {mse(y_true, pred_pretrained):.6f}")
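
A side note on why the RFF embedding works well here. Two standard properties of random Fourier features with Gaussian frequencies are relevant: every embedding has the same norm (so normalizing only rescales), and the cosine similarity between two embeddings approximates an RBF kernel with $\sigma = 1/s$, where $s$ is the frequency scale. With the scale 2.0 used above this gives $\sigma = 0.5$, the same bandwidth as the RBF baselines. The sketch below checks both properties on its own data; it uses many more frequencies than the demo to reduce sampling noise, and the names `W`, `rff`, and `sigma_eff` are illustrative.

```python
import torch

torch.manual_seed(0)
freq_std = 2.0                       # same scale as W_rff above
n_freq = 20000                       # far more frequencies than the demo, to reduce noise
W = torch.randn(1, n_freq, dtype=torch.float64) * freq_std

def rff(X):
    """Random Fourier features: (n, 1) -> (n, 2 * n_freq)."""
    proj = X @ W
    return torch.cat([torch.cos(proj), torch.sin(proj)], dim=1)

x = torch.linspace(-3, 3, 25, dtype=torch.float64).unsqueeze(1)
Z = rff(x)

# Every embedding has norm sqrt(n_freq), because cos^2 + sin^2 = 1 per frequency
norms = Z.norm(dim=1)
norm_spread = (norms.max() - norms.min()).item()

# Cosine similarity between embeddings vs an RBF kernel with sigma = 1 / freq_std
Z_unit = Z / norms.unsqueeze(1)
cos_sim = Z_unit @ Z_unit.T
sigma_eff = 1.0 / freq_std
rbf = torch.exp(-torch.cdist(x, x) ** 2 / (2 * sigma_eff ** 2))
max_err = (cos_sim - rbf).abs().max().item()
print(f"norm spread: {norm_spread:.2e}")
print(f"max |cos_sim - RBF(sigma={sigma_eff})|: {max_err:.3f}")
```

With only 32 frequencies, as in the demo, the approximation is noisier, so the cosine kernel there is a rough rather than exact stand-in for the RBF kernel.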

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(16, 7), sharex=True, sharey=True)

xc = X_ctx.cpu().numpy().ravel()
yc = y_ctx.cpu().numpy().ravel()
xq = X_query.cpu().numpy().ravel()
yt = y_true.cpu().numpy().ravel()

all_titles = [
    "KRR (RBF)", "Nadaraya-Watson (RBF)", "Pre-trained attention",
    r"Cosine ICL ($\cos\theta / T$)", r"Cayley ICL ($-\theta^2/2T^2$)",
    r"GA ICL ($\beta_1\cos\theta - \beta_2\sin\theta$)",
]
all_preds = [pred_krr, pred_nw, pred_pretrained,
             pred_cosine, pred_cayley, pred_ga_pen]

for ax, title, pred in zip(axes.ravel(), all_titles, all_preds):
    ax.scatter(xc, yc, s=5, alpha=0.3, color="gray", label="context")
    ax.plot(xq, yt, "k--", lw=1, label="true sin(x)")
    ax.plot(xq, pred.cpu().detach().numpy().ravel(), "r-", lw=2, label="predicted")
    cur_mse = mse(y_true, pred)
    ax.set_title(f"{title}\nMSE = {cur_mse:.4f}", fontsize=10)
    ax.legend(fontsize=7)

fig.suptitle("GA kernels as in-context learning", fontsize=14)
plt.tight_layout()
plt.show()

### Kernel weight profiles: RBF vs Cosine vs Cayley ($\theta$)

How do these kernels distribute attention across context points for a single query?

In [None]:
# Weight profiles for a single query point
q_idx = 65
zq = Z_query[q_idx:q_idx+1]
zc = Z_ctx

# Normalize for GA kernels
eps = 1e-8
zq_n = zq / zq.norm(dim=1, keepdim=True).clamp(min=eps)
zc_n = zc / zc.norm(dim=1, keepdim=True).clamp(min=eps)

# RBF weights (on raw X)
xq_s = X_query[q_idx:q_idx+1]
K_rbf = rbf_kernel(xq_s, X_ctx, sigma=0.5)
w_rbf = (K_rbf / K_rbf.sum()).cpu().numpy().ravel()

# Cosine weights
w_cos = CosineAttention(temperature=0.1)(zq_n, zc_n).detach().cpu().numpy().ravel()

# Cayley weights
w_cay = CayleyAttention(temperature=0.1)(zq_n, zc_n).detach().cpu().numpy().ravel()

# GA penalty weights
w_ga = GAAttention(beta1=4.0, beta2=1.0, temperature=1.0,
                   wedge_mode="penalty")(zq_n, zc_n).detach().cpu().numpy().ravel()

fig, axes = plt.subplots(1, 4, figsize=(18, 3.5), sharex=True)

query_val = X_query[q_idx].item()
for ax, w, title in zip(axes,
        [w_rbf, w_cos, w_cay, w_ga],
        ["RBF kernel", r"Cosine ($\cos\theta/T$)",
         r"Cayley ($-\theta^2/2T^2$)", r"GA ($\beta_1\cos\theta - \beta_2\sin\theta$)"]):
    ax.bar(xc, w, width=0.04, color="steelblue", alpha=0.7)
    ax.axvline(query_val, color="red", ls="--", lw=2, label=f"query = {query_val:.2f}")
    ax.set_title(title, fontsize=10)
    ax.set_xlabel("context x")
    ax.legend(fontsize=8)

axes[0].set_ylabel("weight")
fig.suptitle("Attention weight profiles for a single query point", fontsize=13)
plt.tight_layout()
plt.show()

## Step 6: Higher Dimensions — Column Attention for Feature Interactions

Everything so far used 1-D input. In higher dimensions ($x \in \mathbb{R}^d$), features can form **correlated groups** — and the task may depend on these hidden groups rather than on individual features.

Standard attention computes **row attention** (query vs context). But GA introduces **column attention** (feature vs feature):

| | Row attention | Column attention |
|---|---|---|
| **Compares** | Samples to samples | Features to features |
| **Matrix** | $(n \times n)$ | $(d \times d)$ |
| **Purpose** | Which context examples matter? | Which features belong together? |
| **GA insight** | $\cos\theta$ = sample similarity | $\sin\theta$ (wedge) = feature independence |

The pipeline:
1. **Column attention** discovers which features are correlated and mixes them — denoising the group signal
2. **Row attention** uses the cleaner representations to find relevant context examples
3. Prediction = weighted sum of context labels

Our task distribution creates **random feature groups**: a random subset of features shares a latent source, and the target depends on this hidden group. Column attention can discover the group structure (correlated features have similar activation patterns across the $n$ context samples) and denoise by averaging within groups.
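
The signal that column attention relies on can be seen without any learning. In the sketch below, each feature is treated as a vector over the context samples, and features are compared pairwise by cosine similarity. The data follows the recipe described above (shared latent plus noise for the group, pure noise elsewhere); the group indices `[1, 4, 7]` and the sample count are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
n_samples, n_features = 500, 10
group = [1, 4, 7]                                   # hypothetical group indices

z = torch.randn(n_samples, 1, dtype=torch.float64)  # shared latent signal
X = 0.5 * torch.randn(n_samples, n_features, dtype=torch.float64)  # distractors
X[:, group] = z + 0.3 * torch.randn(n_samples, len(group), dtype=torch.float64)

# Treat each feature as a vector over samples and compare features pairwise
cols = X.T                                          # (n_features, n_samples)
cols_unit = cols / cols.norm(dim=1, keepdim=True)
feat_cos = cols_unit @ cols_unit.T                  # (n_features, n_features)

others = [j for j in range(n_features) if j not in group]
in_group = feat_cos[group][:, group]                # group vs group block
cross = feat_cos[group][:, others]                  # group vs distractor block

min_in_group = in_group.min().item()
max_cross = cross.abs().max().item()
print(f"min cosine inside the group:      {min_in_group:.2f}")
print(f"max |cosine| group vs distractor: {max_cross:.2f}")
```

The grouped features are strongly aligned with each other and nearly orthogonal to the distractors, which is the structure a $(d \times d)$ attention matrix can pick up and use to average within the group.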

In [None]:
from ga_icl.attention import GAColumnAttention

# --- High-D task: random correlated feature groups ---
# Column attention should discover which features are correlated and mix them.

input_dim = 10   # more features → more room for column attention to help

def sample_hd_task_batch(batch_size, n_ctx, n_query, input_dim=10, noise_std=0.1):
    """Sample tasks with random correlated feature groups.
    
    Each task:
      1. Pick a random subset of k features to be "group" features
      2. These k features all share a latent signal z (plus individual noise)
      3. Target = a * sin(b * z + phi) — depends on the latent, not individual features
      4. Non-group features are pure noise (distractors)
    
    Column attention can discover the group (correlated features have similar
    activation patterns) and average them to denoise z.
    Without column attention, the model must implicitly figure out which
    features matter from the context — much harder.
    """
    k = 3  # group size (k of input_dim features share a latent)
    
    # Random function parameters per task
    a   = 0.5 + 1.5 * torch.rand(batch_size, 1, 1, device=device)
    b   = 0.5 + 2.0 * torch.rand(batch_size, 1, 1, device=device)
    phi = 2 * 3.14159 * torch.rand(batch_size, 1, 1, device=device)
    
    n_total = n_ctx + n_query
    
    # Latent signal shared by group features
    z = torch.randn(batch_size, n_total, 1, device=device)
    
    # Build X: group features = z + noise, distractor features = pure noise
    X_all = torch.randn(batch_size, n_total, input_dim, device=device) * 0.5  # all noise baseline
    
    # For each batch, pick k random features to be the group
    for bi in range(batch_size):
        group_idx = torch.randperm(input_dim, device=device)[:k]
        for fi in group_idx:
            X_all[bi, :, fi] = z[bi, :, 0] + 0.3 * torch.randn(n_total, device=device)
    
    # Target depends on latent signal
    y_all = a * torch.sin(b * z + phi)
    
    X_ctx   = X_all[:, :n_ctx]
    y_ctx   = y_all[:, :n_ctx] + noise_std * torch.randn_like(y_all[:, :n_ctx])
    X_query = X_all[:, n_ctx:]
    y_query = y_all[:, n_ctx:]
    return X_ctx, y_ctx, X_query, y_query

X_c, y_c, X_q, y_q = sample_hd_task_batch(4, 100, 20, input_dim=input_dim)
print(f"Task: 3 of {input_dim} features share a latent → y = a·sin(b·z + φ)")
print(f"Context: X {tuple(X_c.shape)}, y {tuple(y_c.shape)}")
print(f"Query:   X {tuple(X_q.shape)}, y {tuple(y_q.shape)}")

### Unified ICL model: Column attention × Row kernel

We systematically compare all combinations of:
- **Column attention**: off (row-only) vs on (column+row)
- **Row kernel**: Softmax, Cosine ($\cos\theta/T$), Cayley ($-\theta^2/2T^2$), GA ($\beta_1\cos\theta - \beta_2\sin\theta$)

The GA kernel parameters in row attention are **learnable** — they adapt during training just like the column attention parameters. This gives the geometric kernels a fair chance.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class LearnableColumnAttention(nn.Module):
    """GA Column Attention with LEARNABLE parameters + residual gate."""
    def __init__(self, d_model, init_beta1=4.0, init_beta2=1.0,
                 init_diag_bias=2.0, init_wedge_mu=0.5):
        super().__init__()
        self.beta1 = nn.Parameter(torch.tensor(init_beta1))
        self.beta2 = nn.Parameter(torch.tensor(init_beta2))
        self.diag_bias = nn.Parameter(torch.tensor(init_diag_bias))
        self.wedge_mu = nn.Parameter(torch.tensor(init_wedge_mu))
        self.gate = nn.Parameter(torch.tensor(0.0))
    
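    def forward(self, X):
        B, N, D = X.shape
        eps = 1e-6
        X_t = X.transpose(1, 2)                                  # (B, D, N): one row per feature
        u = X_t / X_t.norm(dim=-1, keepdim=True).clamp(min=eps)  # unit-norm feature columns
        # c[b, d, e] = cosine between features d and e, taken over the N samples
        c = torch.einsum("bdn,ben->bde", u, u).clamp(-1 + eps, 1 - eps)
        w = torch.sqrt((1 - c ** 2).clamp(min=0))                # |sin| of the same angle
        score = (self.beta1 * c
                 - self.beta2 * (w - self.wedge_mu) ** 2
                 + self.diag_bias * torch.eye(D, device=X.device).unsqueeze(0))
        alpha = F.softmax(score, dim=-1)                         # (B, D, D) feature-to-feature weights
        # Mix features with alpha, then return to (B, N, D)
        Y = torch.einsum("bde,ben->bdn", alpha, X_t).transpose(1, 2)
        g = torch.sigmoid(self.gate)                             # residual gate, 0.5 at init
        return (1 - g) * X + g * Y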


class UnifiedICL(nn.Module):
    """Unified ICL model: {column attention} × {row kernel type}.
    
    Args:
        input_dim: Raw input dimension
        d_model: Embedding dimension
        use_column_attn: Whether to apply GA column attention on raw features
        row_kernel: "softmax", "cosine", "cayley", or "ga"
    """
    def __init__(self, input_dim=10, d_model=64, use_column_attn=False,
                 row_kernel="softmax"):
        super().__init__()
        self.use_column_attn = use_column_attn
        self.row_kernel = row_kernel
        
        if use_column_attn:
            self.col_attn = LearnableColumnAttention(d_model=input_dim)
        
        self.embed_x = nn.Sequential(
            nn.Linear(input_dim, d_model), nn.ReLU(), nn.Linear(d_model, d_model))
        self.embed_y = nn.Sequential(
            nn.Linear(1, d_model), nn.ReLU(), nn.Linear(d_model, d_model))
        
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, 1)
        self.scale = d_model ** 0.5
        
        # Learnable row-kernel parameters
        if row_kernel in ("cosine", "cayley"):
            self.temperature = nn.Parameter(torch.tensor(0.1))
        if row_kernel == "ga":
            self.temperature = nn.Parameter(torch.tensor(1.0))
            self.ga_beta1 = nn.Parameter(torch.tensor(4.0))
            self.ga_beta2 = nn.Parameter(torch.tensor(1.0))
    
    def _row_weights(self, Q, K):
        """Compute row attention weights using the chosen kernel."""
        if self.row_kernel == "softmax":
            return F.softmax(Q @ K.transpose(-2, -1) / self.scale, dim=-1)
        
        # Normalize for angle-based kernels
        eps = 1e-6
        Q_n = Q / Q.norm(dim=-1, keepdim=True).clamp(min=eps)
        K_n = K / K.norm(dim=-1, keepdim=True).clamp(min=eps)
        cos_theta = (Q_n @ K_n.transpose(-2, -1)).clamp(-1 + eps, 1 - eps)
        
        if self.row_kernel == "cosine":
            score = cos_theta / self.temperature
        elif self.row_kernel == "cayley":
            theta = torch.acos(cos_theta)
            score = -theta ** 2 / (2 * self.temperature ** 2)
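        elif self.row_kernel == "ga":
            sin_theta = torch.sqrt((1 - cos_theta ** 2).clamp(min=0))
            score = (self.ga_beta1 * cos_theta - self.ga_beta2 * sin_theta) / self.temperature
        else:
            # Without this branch an unknown kernel name would leave `score` undefined.
            raise ValueError(f"Unknown row_kernel: {self.row_kernel!r}")
        
        return F.softmax(score, dim=-1)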
    
    def forward(self, X_ctx, y_ctx, X_query):
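        # Optional column attention on raw features.
        # Note: it is applied to context and query rows separately, so the
        # feature-feature cosines are estimated from each set on its own.
        if self.use_column_attn:
            X_ctx = self.col_attn(X_ctx)
            X_query = self.col_attn(X_query)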
        
        # Embed
        ctx_x = self.embed_x(X_ctx)
        qry_x = self.embed_x(X_query)
        ctx_y = self.embed_y(y_ctx)
        
        # Row attention with chosen kernel
        Q = self.W_Q(qry_x)
        K = self.W_K(ctx_x)
        V = self.W_V(ctx_y)
        weights = self._row_weights(Q, K)
        return self.out(weights @ V)

# Verify it works
m = UnifiedICL(input_dim=10, use_column_attn=True, row_kernel="ga").to(device).float()
X_c, y_c, X_q, y_q = sample_hd_task_batch(2, 50, 10, input_dim=10)
print(f"UnifiedICL output shape: {m(X_c, y_c, X_q).shape}")
print("Row kernel types: softmax, cosine, cayley, ga")
print("Column attention: on/off")

### Training all 8 combinations

We train every combination of `{Row-only, Column+Row} × {Softmax, Cosine, Cayley, GA}` on the same task distribution. This gives a clean 2×4 comparison.
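
The 2×4 grid can be enumerated explicitly. This is a minimal sketch, assuming the `(use_col, kernel)` dictionary keys and the `Row `/`Col+` tag prefixes that the evaluation cell uses; `make_tag` is a hypothetical helper, not something the notebook defines.

```python
from itertools import product

row_kernels = ["softmax", "cosine", "cayley", "ga"]

def make_tag(use_col, kernel):
    """Human-readable label for one cell of the grid (hypothetical helper)."""
    return ("Col+" if use_col else "Row ") + kernel.capitalize()

# Keys match the (col_attn, kernel) convention of the `models` dictionary.
grid = {(use_col, kernel): make_tag(use_col, kernel)
        for use_col, kernel in product([False, True], row_kernels)}

for key, tag in grid.items():
    print(key, "->", tag)
```

Building the labels in one place keeps the training log and the results table in agreement on how each configuration is named.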

In [None]:
def train_model(model, n_steps=10000, batch_size=64, n_ctx=100, n_query=20,
                input_dim=10, lr=1e-3):
    """Train a model on random high-dimensional tasks."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for step in range(n_steps):
        X_c, y_c, X_q, y_q = sample_hd_task_batch(
            batch_size, n_ctx, n_query, input_dim=input_dim)
        pred = model(X_c.float(), y_c.float(), X_q.float())
        loss = ((pred - y_q.float()) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

torch.manual_seed(42)

row_kernels = ["softmax", "cosine", "cayley", "ga"]
models = {}      # (col_attn, kernel) → model
all_losses = {}  # (col_attn, kernel) → losses

for use_col in [False, True]:
    for kernel in row_kernels:
        tag = ("Col+" if use_col else "Row ") + kernel.capitalize()  # same labels as the evaluation cell
        print(f"Training {tag}...", end=" ", flush=True)
        m = UnifiedICL(input_dim=input_dim, d_model=64,
                       use_column_attn=use_col, row_kernel=kernel).to(device).float()
        losses = train_model(m, n_steps=10000, input_dim=input_dim)
        models[(use_col, kernel)] = m
        all_losses[(use_col, kernel)] = losses
        print(f"final loss: {losses[-1]:.4f}")

# Show learned column attention parameters for Col+Softmax
ca = models[(True, "softmax")].col_attn
print("\nLearned column attention params (Col+Softmax):")
print(f"  beta1={ca.beta1.item():.2f}  beta2={ca.beta2.item():.2f}  "
      f"diag_bias={ca.diag_bias.item():.2f}  wedge_mu={ca.wedge_mu.item():.2f}  "
      f"gate={torch.sigmoid(ca.gate).item():.2f}")

# Training curves
fig, axes = plt.subplots(1, 4, figsize=(18, 3.5), sharey=True)
window = 200
colors_row = {"softmax": "coral", "cosine": "#e67e22", "cayley": "#e74c3c", "ga": "#c0392b"}
colors_col = {"softmax": "steelblue", "cosine": "#2980b9", "cayley": "#3498db", "ga": "#1abc9c"}

for ax, kernel in zip(axes, row_kernels):
    for use_col, cmap in [(False, colors_row), (True, colors_col)]:
        losses = all_losses[(use_col, kernel)]
        sm = [sum(losses[max(0,i-window):i+1])/len(losses[max(0,i-window):i+1])
              for i in range(len(losses))]
        label = ("Col+" if use_col else "Row ") + kernel
        ax.plot(losses, alpha=0.08, color=cmap[kernel])
        ax.plot(sm, color=cmap[kernel], lw=2, label=label)
    ax.set_title(kernel.capitalize(), fontsize=12)
    ax.set_xlabel("Step")
    ax.legend(fontsize=8)
axes[0].set_ylabel("MSE loss")
fig.suptitle(f"Training curves: Row-only vs Column+Row × kernel type (dim={input_dim})", fontsize=13)
plt.tight_layout()
plt.show()

### Evaluation: 2×4 comparison on held-out tasks

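Test all 8 trained models + KRR baseline on 200 new tasks. Note that each evaluation task uses 200 context points and 50 queries, whereas training used 100 context points and 20 queries. This reveals: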
- Which **row kernel** works best for ICL?
- Does **column attention** help consistently across kernel types?
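
One way to make "consistently" concrete is to compare the two variants task by task rather than only through their means. The sketch below is an assumption-laden illustration: `paired_summary` is a hypothetical helper, and the numbers are made up, not results from this notebook. The percentage formula mirrors the one used in the evaluation cell; the per-task win rate is an extra diagnostic added here.

```python
def paired_summary(row_mse, col_mse):
    """Compare per-task MSE lists for a row-only model and its column+row twin.

    Both lists must be evaluated on the same tasks, in the same order.
    """
    if not row_mse or len(row_mse) != len(col_mse):
        raise ValueError("need two non-empty lists of equal length")
    n = len(row_mse)
    row_mean = sum(row_mse) / n
    col_mean = sum(col_mse) / n
    # Relative improvement of the mean, in percent (positive = column attention helps).
    pct = 100 * (row_mean - col_mean) / row_mean
    # Fraction of tasks on which the column+row model has strictly lower MSE.
    win_rate = sum(c < r for r, c in zip(row_mse, col_mse)) / n
    return {"row_mean": row_mean, "col_mean": col_mean,
            "pct_improvement": pct, "win_rate": win_rate}

# Made-up per-task errors, for illustration only.
row = [0.20, 0.10, 0.30, 0.40]
col = [0.10, 0.12, 0.15, 0.20]
summary = paired_summary(row, col)
print(summary)
```

A large mean improvement combined with a low win rate would suggest that a few hard tasks drive the gain, which the mean alone cannot show.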

In [None]:
torch.manual_seed(123)

n_eval_tasks = 200
n_ctx_eval = 200
n_query_eval = 50

# Result containers
result_names = ["KRR (RBF)"]
for use_col in [False, True]:
    for kernel in row_kernels:
        tag = ("Col+" if use_col else "Row ") + kernel.capitalize()
        result_names.append(tag)
results = {name: [] for name in result_names}

for m in models.values():
    m.eval()

for t in range(n_eval_tasks):
    X_c, y_c, X_q, y_q = sample_hd_task_batch(
        1, n_ctx_eval, n_query_eval, input_dim=input_dim, noise_std=0.1)
    
    xc = X_c.squeeze(0).double()
    yc = y_c.squeeze(0).double()
    xq = X_q.squeeze(0).double()
    yq = y_q.squeeze(0).double()
    
    def mse_fn(pred):
        return ((pred - yq) ** 2).mean().item()
    
    # KRR baseline
    results["KRR (RBF)"].append(
        mse_fn(KernelRidge(kernel_fn=rbf_kernel, lam=0.01, sigma=3.0).fit(xc, yc).predict(xq)))
    
    # All 8 trained models
    with torch.no_grad():
        for use_col in [False, True]:
            for kernel in row_kernels:
                tag = ("Col+" if use_col else "Row ") + kernel.capitalize()
                m = models[(use_col, kernel)]
                pred = m(X_c.float(), y_c.float(), X_q.float()).squeeze(0).double()
                results[tag].append(mse_fn(pred))

# Print results table
print(f"Mean MSE over {n_eval_tasks} held-out tasks (dim = {input_dim}):\n")
print(f"  {'Method':20s}  {'MSE':>8s}")
print(f"  {'─'*20}  {'─'*8}")
for name in result_names:
    mean_mse = sum(results[name]) / len(results[name])
    print(f"  {name:20s}  {mean_mse:.4f}")

# Column attention improvement per kernel
print("\nColumn attention improvement per kernel:")
for kernel in row_kernels:
    row_tag = "Row " + kernel.capitalize()
    col_tag = "Col+" + kernel.capitalize()
    row_mse = sum(results[row_tag]) / len(results[row_tag])
    col_mse = sum(results[col_tag]) / len(results[col_tag])
    pct = 100 * (row_mse - col_mse) / row_mse
    print(f"  {kernel:8s}:  Row {row_mse:.4f} → Col+Row {col_mse:.4f}  ({pct:+.1f}%)")

In [None]:
# Left: box plot of all methods. Right: paired bars, Row-only vs Column+Row per kernel
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: Box plot of all methods
names = result_names
data = [results[n] for n in names]
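bp = ax1.boxplot(data, patch_artist=True, showfliers=False,
                 medianprops=dict(color="black", lw=2))
# Set tick labels separately: recent Matplotlib releases deprecate boxplot's `labels=` argument
ax1.set_xticklabels([n.replace("Row ", "R:").replace("Col+", "C+") for n in names])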
colors = (["#d4d4d4"]                                         # KRR
          + ["#ffb3b3", "#ffc9a3", "#ffa3a3", "#ff8080"]      # Row: softmax,cos,cay,ga
          + ["#a3c4ff", "#93b8ff", "#83acff", "#73a0ff"])      # Col: softmax,cos,cay,ga
for patch, color in zip(bp["boxes"], colors):
    patch.set_facecolor(color)
means = [sum(v)/len(v) for v in data]
ax1.scatter(range(1, len(means)+1), means, color="red", zorder=5, s=40, marker="D")
ax1.set_ylabel("MSE")
ax1.set_title("All methods", fontsize=12)
ax1.tick_params(axis='x', rotation=45)

# Right: Paired bar chart — Row vs Col+Row per kernel
import numpy as np
x = np.arange(len(row_kernels))
width = 0.35
row_means = [sum(results["Row " + k.capitalize()])/n_eval_tasks for k in row_kernels]
col_means = [sum(results["Col+" + k.capitalize()])/n_eval_tasks for k in row_kernels]
krr_mean  = sum(results["KRR (RBF)"]) / n_eval_tasks

bars1 = ax2.bar(x - width/2, row_means, width, label="Row only", color="coral", alpha=0.85)
bars2 = ax2.bar(x + width/2, col_means, width, label="Col+Row", color="steelblue", alpha=0.85)
ax2.axhline(krr_mean, color="gray", ls="--", lw=1.5, label=f"KRR (RBF) = {krr_mean:.3f}")
ax2.set_xticks(x)
ax2.set_xticklabels([k.capitalize() for k in row_kernels])
ax2.set_ylabel("Mean MSE")
ax2.set_xlabel("Row kernel type")
ax2.set_title("Column attention benefit per kernel", fontsize=12)
ax2.legend(fontsize=9)

# Add improvement % labels
for i, (r, c) in enumerate(zip(row_means, col_means)):
    pct = 100 * (r - c) / r
    ax2.annotate(f"{pct:+.0f}%", xy=(i + width/2, c), xytext=(0, 5),
                 textcoords="offset points", ha="center", fontsize=9, fontweight="bold",
                 color="darkblue")

fig.suptitle(f"ICL performance: Row-only vs Column+Row × kernel type (dim={input_dim})", fontsize=13)
plt.tight_layout()
plt.show()