# Emotion-Specific Attention (ESA)

Emotion-Specific Attention (ESA) is an attention-based layer introduced in the paper
*Emotion-Aware RoBERTa enhanced with emotion-specific attention and TF-IDF gating*.

ESA is designed to **refine contextual token representations by amplifying
emotion-relevant latent features**, without modifying the internal attention
mechanism of RoBERTa.

Key ideas:
- ESA operates **after** RoBERTa has produced contextual embeddings
- It uses **standard self-attention**
- Emotional awareness is introduced via a **learnable feature-scaling vector**
- The scaling vector does **not** assign importance to individual tokens; it weights embedding (hidden) dimensions, with the same weights applied at every token position

ESA is task-specific and trained end-to-end via the final emotion classification loss.
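The toy sketch below makes the last key idea concrete. Its sizes and values are arbitrary and chosen only for illustration. It contrasts feature-wise scaling by a vector over the hidden dimension, which is how S is used in this notebook, with a per-token weighting, which S does not perform.

```python
import torch

B, L, H = 1, 3, 4  # toy sizes for illustration only
Z = torch.ones(B, L, H)  # stand-in for the attention output

# Feature-wise scaling (how S is used): one weight per hidden dimension,
# shared by every token in every sequence.
S = torch.tensor([2.0, 1.0, 0.5, 0.0])  # shape [H]
Z_feature = Z * S  # broadcasts [H] -> [B, L, H]

# Token-wise weighting (NOT what S does): one weight per token position.
w = torch.tensor([[1.0, 0.0, 3.0]])  # shape [B, L]
Z_token = Z * w.unsqueeze(-1)  # broadcasts [B, L, 1] -> [B, L, H]

print(Z_feature[0])  # every row (token) equals S
print(Z_token[0])    # each row is constant: all 1s, all 0s, all 3s
```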


In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

## Dummy input example

In [28]:
BATCH_SIZE = 2
SEQ_LEN = 16
HIDDEN_DIM = 768  # hidden size of roberta-base

print("Batch size:", BATCH_SIZE)
print("Sequence length:", SEQ_LEN)
print("Hidden dimension:", HIDDEN_DIM)

Batch size: 2
Sequence length: 16
Hidden dimension: 768


In [29]:
# Simulated RoBERTa output embeddings
# Shape: [B, L, H]
E = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)

print("E shape:", E.shape)

E shape: torch.Size([2, 16, 768])


In [35]:
# Attention mask: 1 = real token, 0 = padding
attention_mask = torch.tensor([
    [1]*12 + [0]*4,     # 12 real tokens + 4 pads
    [1]*16              # all real
])

print("Attention mask shape:", attention_mask.shape)
print(attention_mask)

Attention mask shape: torch.Size([2, 16])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
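The same mask can also be derived from per-sample lengths instead of being written out by hand. The sketch below assumes the lengths are known (12 and 16 here, matching the cell above). In practice, a Hugging Face tokenizer returns an `attention_mask` like this directly when padding is enabled.

```python
import torch

SEQ_LEN = 16
lengths = torch.tensor([12, 16])  # number of real tokens per sample

# positions [1, L] compared against lengths [B, 1] -> boolean [B, L]
positions = torch.arange(SEQ_LEN).unsqueeze(0)
attention_mask = (positions < lengths.unsqueeze(1)).long()  # 1 = real, 0 = pad

print(attention_mask.shape)       # torch.Size([2, 16])
print(attention_mask.sum(dim=1))  # real-token count per sample
```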


# Next Step

We are now ready to implement the ESA layer itself.

Next we will:
1. Add positional encoding
2. Implement standard scaled dot-product attention
3. Introduce the learnable emotion-scaling vector S ∈ ℝᴴ
4. Apply feature-wise scaling
5. Re-add positional encoding
6. Verify shapes and gradients

No TF-IDF gating, no training loop, no optimization yet.
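The five steps above can be summarized as follows. The notation is our own: it is a reading of the formulation as implemented later in this notebook, not the paper's exact notation. It is shown in single-head form, with $M$ an additive padding mask that is $0$ for real keys and $-\infty$ for pad keys, and $S$ broadcast over batch and sequence positions.

$$
\begin{aligned}
E_{\text{input}} &= E + P \\
(Q,\,K,\,V) &= (E_{\text{input}} W_Q,\; E_{\text{input}} W_K,\; E_{\text{input}} W_V) \\
Z &= \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{H}} + M\right) V \\
Z_{\text{scaled}} &= Z \odot S, \qquad S \in \mathbb{R}^{H} \\
Z_{\text{final}} &= Z_{\text{scaled}} + P
\end{aligned}
$$

In the RoBERTa-based setting discussed below, the first line reduces to $E_{\text{input}} = E$.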

## 1. Positional encoding

In [31]:
# OPTIONAL / EDUCATIONAL ONLY
# This demonstrates how positional encodings are typically constructed in vanilla Transformers.
# We will NOT apply this to RoBERTa outputs (last_hidden_state), because RoBERTa already includes
# learned positional embeddings internally.

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, hidden_dim: int, max_len: int = 512):
        super().__init__()

        position = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2) * (-torch.log(torch.tensor(10000.0)) / hidden_dim)
        )

        pe = torch.zeros(max_len, hidden_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, L, H]
        seq_len = x.size(1)
        return self.pe[:seq_len]  # [L, H]


# Quick demo (again: NOT used later)
pos_enc = SinusoidalPositionalEncoding(hidden_dim=HIDDEN_DIM, max_len=SEQ_LEN)
P_demo = pos_enc(E)  # [L, H]
print("P_demo shape:", P_demo.shape)


P_demo shape: torch.Size([16, 768])
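The vectorized code above implements the standard sinusoidal encoding of the original Transformer. Here `div_term` equals $10000^{-2i/H}$:

$$
P_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/H}}\right), \qquad
P_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/H}}\right)
$$

In these formulas:
- $pos$ indexes the token position (a row of `pe`).
- $i = 0, \dots, H/2 - 1$ indexes pairs of hidden dimensions.

Computing the factor as `exp(2i * -ln(10000) / H)` is mathematically equivalent to the power form.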


## Why the paper writes `E_input = E + P`, and what we do in practice

In the ESA formulation, the paper explicitly adds positional encodings:

`E_input = E + P`

This is the standard Transformer recipe. Token embeddings `E` need positional information `P` because self-attention by itself has no notion of token order: it is permutation-equivariant.
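The sketch below is a quick numerical check of why attention needs `P`. Without positional information, permuting the input tokens of a self-attention layer just permutes its outputs in the same way, so the layer cannot distinguish token orders. The check uses PyTorch's `nn.MultiheadAttention` with toy sizes chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, H = 1, 5, 8  # toy sizes for illustration only
x = torch.randn(B, L, H)
mha = nn.MultiheadAttention(embed_dim=H, num_heads=2, batch_first=True)
mha.eval()

perm = torch.randperm(L)  # a random reordering of the token positions
xp = x[:, perm]           # the same tokens, shuffled

with torch.no_grad():
    out, _ = mha(x, x, x)          # attention on the original order
    out_perm, _ = mha(xp, xp, xp)  # attention on the shuffled tokens

# The shuffled input yields the original outputs, shuffled the same way.
print(torch.allclose(out_perm, out[:, perm], atol=1e-5))  # True
```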

However, in our implementation `E` comes from:

`E = roberta_outputs.last_hidden_state`

RoBERTa already adds **learned positional embeddings** inside its embedding layer
(before the transformer blocks). Therefore, `last_hidden_state` already contains
positional information.

So in practice, for ESA built *on top of RoBERTa outputs*, we do:

- `E_input = E`

and we do **not** add an additional external positional encoding, to avoid double-counting position.

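The sinusoidal implementation above is kept for reference only and is not used anywhere else in this notebook. However, the `ESALayer` in Section 3 follows the paper's equations literally. It adds its own *learned* positional embedding `P` (an `nn.Embedding`) before attention and re-adds it after scaling. To follow the `E_input = E` choice strictly, those `+ P` terms would be removed from its forward pass.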


In [32]:
# This is the actual input to ESA in our setup:
E_input = E  # RoBERTa already injected positional information
print("E_input shape:", E_input.shape)

E_input shape: torch.Size([2, 16, 768])


## 2. Self-attention module

In [33]:
class ScaledDotProductSelfAttention(nn.Module):
    """
    Minimal single-head self-attention:
    Input:  E_input  [B, L, H]
    Mask:   attention_mask [B, L] with 1=real token, 0=pad
    Output: Z [B, L, H] and attn_probs [B, L, L]
    """
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Linear projections
        self.Wq = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wv = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, E_input: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        B, L, H = E_input.size()
        assert H == self.hidden_dim, f"Expected hidden_dim={self.hidden_dim}, got H={H}"

        # 1. Project to Q, K, V
        Q = self.Wq(E_input)  # [B, L, H]
        K = self.Wk(E_input)  # [B, L, H]
        V = self.Wv(E_input)  # [B, L, H]

        # 2. Scaled dot-product attention scores
        # scores[b] = Q[b] @ K[b].T => [L, L]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / (H ** 0.5)  # [B, L, L]

        # 3. Apply attention mask (if provided)
        if attention_mask is not None:
            # attention_mask: [B, L] -> key_mask: [B, 1, L], broadcast over query positions
            key_mask = attention_mask.unsqueeze(1)  # [B, 1, L]
            scores = scores.masked_fill(key_mask == 0, float('-inf'))

        # 4. Softmax to get attention probabilities
        attn_probs = F.softmax(scores, dim=-1)  # [B, L, L]

        # If every key in a row is masked, that row of `scores` is all -inf and softmax returns NaNs.
        # This cannot happen as long as each sample has at least one real token.
        if torch.isnan(attn_probs).any():
            raise ValueError("NaNs detected in attn_probs. Check attention_mask and padding.")

        # 5. Weighted sum of values
        Z = torch.matmul(attn_probs, V)  # [B, L, H]

        return Z, attn_probs
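As a sanity check, the same masked single-head computation can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which is available in PyTorch 2.0 and later. The sketch below makes two simplifications: it uses random Q, K and V directly instead of the projection layers, and it uses toy sizes. For a boolean `attn_mask`, `True` marks keys that may be attended to. This matches the 1 = real convention here and is the opposite of `key_padding_mask` in `nn.MultiheadAttention`, where `True` means "ignore".

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, H = 2, 6, 8  # toy sizes for illustration only
Q, K, V = (torch.randn(B, L, H) for _ in range(3))
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])  # 1 = real token, 0 = pad

# Manual version, mirroring ScaledDotProductSelfAttention.forward
scores = Q @ K.transpose(-1, -2) / H ** 0.5  # [B, L, L]
scores = scores.masked_fill(attention_mask.unsqueeze(1) == 0, float("-inf"))
Z_manual = F.softmax(scores, dim=-1) @ V  # [B, L, H]

# Built-in version (default scale is 1/sqrt(H), same as above)
Z_builtin = F.scaled_dot_product_attention(
    Q.unsqueeze(1), K.unsqueeze(1), V.unsqueeze(1),     # add a head dim: [B, 1, L, H]
    attn_mask=attention_mask[:, None, None, :].bool(),  # [B, 1, 1, L], True = may attend
).squeeze(1)                                            # back to [B, L, H]

print(torch.allclose(Z_manual, Z_builtin, atol=1e-5))  # True
```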

In [36]:
attn = ScaledDotProductSelfAttention(hidden_dim=HIDDEN_DIM)

Z, attn_probs = attn(E_input, attention_mask=attention_mask)

print("Z shape:", Z.shape)                 # [B, L, H]
print("attn_probs shape:", attn_probs.shape)  # [B, L, L]

Z shape: torch.Size([2, 16, 768])
attn_probs shape: torch.Size([2, 16, 16])


In [37]:
# For sample 0, positions 12-15 are PAD (mask=0),
# so the attention probability assigned to those keys should be exactly 0.
pad_positions = (attention_mask[0] == 0).nonzero(as_tuple=True)[0]
print("Attention mask (sample 0):", attention_mask[0].tolist())
print("Sum of attention mass going to PAD keys (sample 0):",
      attn_probs[0][:, pad_positions].sum().item())


Attention mask (sample 0): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
Sum of attention mass going to PAD keys (sample 0): 0.0
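The NaN check inside `ScaledDotProductSelfAttention.forward` guards against a hypothetical sample whose mask is all zeros. Every score in that row becomes `-inf`, and softmax over a row containing only `-inf` is undefined. The demonstration below uses a toy score row; when at least one key is real, masked keys instead receive exactly zero probability.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[0.5, 1.0, -0.3]])  # one query row over three keys
key_mask = torch.tensor([[0, 0, 0]])       # hypothetical sample with no real tokens

masked = scores.masked_fill(key_mask == 0, float("-inf"))
probs = F.softmax(masked, dim=-1)
print(probs)  # all entries are nan

# With at least one real key, masked positions get exactly zero probability.
partial = scores.masked_fill(torch.tensor([[1, 1, 0]]) == 0, float("-inf"))
p_partial = F.softmax(partial, dim=-1)
print(p_partial)  # last entry is exactly 0
```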


## 3. Learnable emotion-scaling vector S ∈ ℝᴴ and the full ESA layer

In [38]:
class ESALayer(nn.Module):
    """
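    ESA = standard multi-head self-attention + learnable feature-scaling vector S ∈ R^H,
    with a learned positional embedding P added before attention and re-added after
    scaling (following the paper's E_input = E + P formulation).

    Input:
      E: [B, L, H]  (e.g. RoBERTa last_hidden_state)
      attention_mask: [B, L] (1=real, 0=pad)

    Output:
      Z_final: [B, L, H]  (Z ⊙ S + P)
      attn_weights: [B, num_heads, L, L]  (per-head weights, useful for debugging/visualization)

    Note: hidden_dim must be divisible by num_heads (nn.MultiheadAttention requirement).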
    """
    def __init__(self, hidden_dim: int, num_heads: int = 2, max_len: int = 512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.max_len = max_len

        # Learned positional encoding P ∈ R[max_len, H]
        self.pos_emb = nn.Embedding(max_len, hidden_dim)

        # Multi-head self-attention (Transformer-style)
        self.mha = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)

        # Learnable scaling vector S ∈ R[H]
        self.S = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, E: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """
        E: [B, L, H]  (e.g., roberta_outputs.last_hidden_state)
        attention_mask: [B, L] with 1=real token, 0=pad
        """
        B, L, H = E.shape
        assert H == self.hidden_dim
        assert L <= self.max_len, f"Sequence length {L} exceeds max_len {self.max_len}"

        # Build positions [L] and expand to [B, L] for embedding lookup
        positions = torch.arange(L, device=E.device).unsqueeze(0).expand(B, L)  # [B, L]
        P = self.pos_emb(positions)  # [B, L, H]

        # (1) E_input = E + P
        E_input = E + P

        # Prepare key padding mask for MHA: True means "ignore"
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = (attention_mask == 0)  # [B, L] boolean

        # (2) Standard multi-head self-attention (query = key = value = E_input)
        Z, attn_weights = self.mha(
            E_input, E_input, E_input,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=False  # return per-head weights [B, num_heads, L, L] instead of their mean over heads
        )  # Z: [B, L, H]

        # (4) Emotion-specific scaling with S (introduced in __init__, step 3): Z_scaled = Z ⊙ S
        Z_scaled = Z * self.S  # broadcasts [H] -> [B, L, H]

        # (5) Re-add positional encoding: Z_final = Z_scaled + P
        Z_final = Z_scaled + P

        return Z_final, attn_weights

In [43]:
esa = ESALayer(hidden_dim=HIDDEN_DIM, num_heads=12, max_len=514)
Z_final, attn_w = esa(E_input, attention_mask=attention_mask)

print("Z_final:", Z_final.shape)        # [B, L, H]
print("attn_w:", attn_w.shape)          # [B, num_heads, L, L]
print("S:", esa.S.shape)                # [H]

Z_final: torch.Size([2, 16, 768])
attn_w: torch.Size([2, 12, 16, 16])
S: torch.Size([768])
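The sketch below puts the size of S in perspective. It rebuilds the three parameterized pieces of `ESALayer` standalone, using the settings from the cell above (H = 768, 12 heads, max_len = 514), and counts their parameters. The closed-form counts in the comments follow from the default `nn.MultiheadAttention` parameter shapes: a packed Q/K/V input projection plus an output projection, both with biases.

```python
import torch
import torch.nn as nn

H, NUM_HEADS, MAX_LEN = 768, 12, 514

mha = nn.MultiheadAttention(embed_dim=H, num_heads=NUM_HEADS, batch_first=True)
pos_emb = nn.Embedding(MAX_LEN, H)
S = nn.Parameter(torch.ones(H))

def n_params(module: nn.Module) -> int:
    """Total number of scalar parameters in a module."""
    return sum(p.numel() for p in module.parameters())

counts = {
    "mha": n_params(mha),          # 3H*H + 3H (in_proj) + H*H + H (out_proj) = 4H^2 + 4H
    "pos_emb": n_params(pos_emb),  # MAX_LEN * H
    "S": S.numel(),                # H
}
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name:8s} {n:>10,d}  ({100 * n / total:.3f}% of total)")
```

With these settings, S accounts for only 768 of roughly 2.76M parameters (about 0.03%). The emotion-specific scaling is therefore a very lightweight addition on top of the attention block.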


In [47]:
esa = ESALayer(hidden_dim=HIDDEN_DIM, num_heads=12, max_len=514)
esa.train()
Z_final, _ = esa(E_input, attention_mask=attention_mask)

loss = Z_final.mean()  # dummy scalar loss, used only to check that gradients flow
esa.zero_grad(set_to_none=True)
loss.backward()

print("S grad exists:", esa.S.grad is not None)
print("pos_emb grad exists:", esa.pos_emb.weight.grad is not None)
if esa.pos_emb.weight.grad is not None:
    print("pos_emb grad norm:", esa.pos_emb.weight.grad.norm().item())

S grad exists: True
pos_emb grad exists: True
pos_emb grad norm: 0.009953795000910759
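The gradient check can be made quantitative for S. With `loss = Z_final.mean()` and `Z_final = Z ⊙ S + P`, the chain rule gives:

∂loss/∂S_h = (1 / (B·L·H)) · Σ_{b,l} Z_{b,l,h}

Because Z does not depend on S, the same formula holds inside the full layer. The sketch below verifies it with toy tensors. It treats Z as a fixed input (no attention layer), so only the scaling and re-adding steps are tested.

```python
import torch

torch.manual_seed(0)
B, L, H = 2, 4, 6  # toy sizes for illustration only
Z = torch.randn(B, L, H)  # stand-in for the attention output (held fixed)
P = torch.randn(L, H)     # stand-in for positional embeddings, broadcast over B
S = torch.ones(H, requires_grad=True)

Z_final = Z * S + P  # scaling, then re-adding P
loss = Z_final.mean()
loss.backward()

expected = Z.sum(dim=(0, 1)) / (B * L * H)  # analytic d(loss)/dS
print(torch.allclose(S.grad, expected, atol=1e-7))  # True
```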


In [46]:
# Identify pad positions for sample 0
pad_positions = (attention_mask[0] == 0).nonzero(as_tuple=True)[0]
print("Pad positions sample 0:", pad_positions.tolist())

if len(pad_positions) > 0:
    pad_attention_mass = attn_w[0, :, :, pad_positions].sum().item()
    print("Total attention mass going to PAD keys (sample 0):", pad_attention_mass)
else:
    print("No padding in sample 0 to test masking.")


Pad positions sample 0: [12, 13, 14, 15]
Total attention mass going to PAD keys (sample 0): 0.0
