<a href="https://colab.research.google.com/github/D-Keqi/mtla/blob/main/assets/MTLA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-head Temporal Latent Attention (MTLA) Demo

This notebook demonstrates:
1. Training a simple language model using MTLA
2. Implementing beam search decoding with incremental state
3. Showing how the temporal compression works during inference

[GitHub Project](https://github.com/D-Keqi/mtla)

For a more complete example of MTLA integrated into Fairseq, see the [Fairseq Transformer decoder implementation](https://github.com/D-Keqi/mtla/blob/main/experiments/tools/fairseq/fairseq/models/transformer/transformer_decoder.py#L1638).


In [1]:
import os
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# Global seed so weight initialisation and data shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x782e91501250>

## 1. Load the MTLA Module

In [2]:
!git clone https://github.com/D-Keqi/mtla.git
%cd mtla

Cloning into 'mtla'...
remote: Enumerating objects: 2313, done.[K
remote: Counting objects: 100% (2313/2313), done.[K
remote: Compressing objects: 100% (1753/1753), done.[K
remote: Total 2313 (delta 520), reused 2218 (delta 461), pack-reused 0 (from 0)[K
Receiving objects: 100% (2313/2313), 13.11 MiB | 17.37 MiB/s, done.
Resolving deltas: 100% (520/520), done.
/content/mtla


In [3]:
from MTLA import MultiheadTemporalLatentAttention

## 2. Create a Simple Language Model with MTLA

In [4]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block that uses MTLA for self-attention.

    Layout: ``x + Dropout(MTLA(LN(x)))`` followed by ``x + FFN(LN(x))``.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        dropout: dropout rate passed to MTLA, applied to the attention output
            and applied at the end of the feed-forward network.
        **mtla_kwargs: optional overrides for the MTLA hyper-parameters in
            ``DEFAULT_MTLA_KWARGS`` (e.g. ``down_rate=4``). Keys that are not
            given keep the demo defaults, so ``TransformerBlock(d, h)`` builds
            exactly the same module as before.
    """

    # MTLA hyper-parameters used by this demo.
    # NOTE(review): `down_rate` is presumably the temporal compression factor
    # of the cached keys/values; see the MTLA README for each key's meaning.
    DEFAULT_MTLA_KWARGS = {
        "q_lora_rank": 0,
        "kv_lora_rank": 256,
        "qk_nope_head_dim": 64,
        "qk_rope_head_dim": 32,
        "v_head_dim": 64,
        "down_rate": 2,
        "recompute_prompt_attn": True,
    }

    def __init__(self, embed_dim, num_heads, dropout=0.1, **mtla_kwargs):
        super().__init__()
        attn_config = {**self.DEFAULT_MTLA_KWARGS, **mtla_kwargs}
        # Submodules are created in the same order as before so parameter
        # initialisation under a fixed seed is unchanged.
        self.self_attn = MultiheadTemporalLatentAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            **attn_config,
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, position, incremental_state=None, self_attn_mask=None):
        """Apply attention and feed-forward sub-layers, each with a residual.

        Args:
            x: hidden states; the sub-layer outputs are added back onto it.
            position: token positions, forwarded to MTLA unchanged.
            incremental_state: decoding cache forwarded to MTLA (None for a
                full-sequence pass).
            self_attn_mask: optional mask forwarded to MTLA unchanged.

        Returns:
            The updated hidden states.
        """
        # Self-attention sub-layer (pre-norm). The full normalised sequence is
        # passed as query, key and value; MTLA receives the incremental state.
        residual = x
        x = self.norm1(x)
        x = self.self_attn(
            query=x,
            key=x,
            value=x,
            position=position,
            incremental_state=incremental_state,
            self_attn_mask=self_attn_mask
        )
        x = residual + self.dropout(x)

        # Feed-forward sub-layer (pre-norm).
        residual = x
        x = self.ffn(self.norm2(x))
        return residual + x

class SimpleLM(nn.Module):
    """Decoder-only LM: embedding -> MTLA TransformerBlocks -> LayerNorm -> vocab projection.

    There is no separate positional-encoding layer. Token positions are passed
    to every block's MTLA through ``position``; presumably they feed its rotary
    (``qk_rope_head_dim``) part. Confirm in MTLA.py.
    """

    def __init__(self, vocab_size, embed_dim=512, num_heads=8, num_layers=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x, incremental_state=None, positions=None):
        """Compute next-token logits for every input position.

        Args:
            x: token ids of shape (batch, seq_len).
            incremental_state: decoding cache shared by all layers, or None for
                a full-sequence (training) pass.
            positions: positions of the tokens in ``x``. Defaults to
                ``0..seq_len-1`` with shape (1, seq_len).

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        if positions is None:
            positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)

        hidden = self.embedding(x)

        # No explicit causal mask is passed to the blocks. (A plain
        # (seq_len, seq_len) mask used to be built here and then discarded.)
        # NOTE(review): this assumes MTLA applies its own causal masking
        # internally during full-sequence passes. Confirm in MTLA.py before
        # relying on training being strictly autoregressive.
        for layer in self.layers:
            hidden = layer(
                hidden,
                position=positions,
                incremental_state=incremental_state,
                self_attn_mask=None,
            )

        return self.head(self.norm(hidden))

    def reorder_incremental_state(self, incremental_state, new_order):
        """Make every layer's cached attention state follow ``new_order``.

        Used by beam search after selecting the surviving beams. Each layer's
        MTLA module does the actual reordering. No-op when
        ``incremental_state`` is None.
        """
        if incremental_state is None:
            return

        for layer in self.layers:
            layer.self_attn.reorder_incremental_state(incremental_state, new_order)

## 3. Create a Simple Dataset

In [5]:
class SimpleTextDataset(Dataset):
    """Character-level LM dataset of fixed-length, overlapping token windows.

    Each text is tokenised character by character. Every window of exactly
    ``seq_length`` tokens, starting at multiples of ``stride``, becomes one
    sample. Texts shorter than ``seq_length`` contribute nothing.

    Args:
        texts: iterable of strings.
        vocab: mapping from character to token id. Must contain ``'<unk>'``,
            which is used for characters missing from the mapping.
        seq_length: number of tokens per sample (>= 1).
        stride: distance between consecutive window starts. Defaults to
            ``seq_length // 2`` (at least 1), i.e. half-overlapping windows.

    Raises:
        ValueError: if ``seq_length`` or ``stride`` is smaller than 1.
    """

    def __init__(self, texts, vocab, seq_length=64, stride=None):
        if seq_length < 1:
            raise ValueError(f"seq_length must be >= 1, got {seq_length}")
        if stride is None:
            # seq_length // 2 is 0 for seq_length == 1, which range() rejects.
            stride = max(1, seq_length // 2)
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self.vocab = vocab
        self.seq_length = seq_length
        self.stride = stride
        self.data = []

        for text in texts:
            tokens = [vocab.get(c, vocab['<unk>']) for c in text]
            # "+ 1" keeps the last valid window: a text of exactly seq_length
            # characters yields one sample. The bound also guarantees every
            # slice below is full length.
            for start in range(0, len(tokens) - seq_length + 1, stride):
                self.data.append(tokens[start:start + seq_length])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a 1-D ``torch.long`` tensor of token ids."""
        return torch.tensor(self.data[idx], dtype=torch.long)

# Toy corpus: twenty well-known Shakespeare lines.
texts = [
    "To be, or not to be: that is the question",
    "All the world's a stage, and all the men and women merely players",
    "Romeo, Romeo! Wherefore art thou Romeo?",
    "What's in a name? That which we call a rose by any other name would smell as sweet",
    "The lady doth protest too much, methinks",
    "Brevity is the soul of wit",
    "Uneasy lies the head that wears the crown",
    "Parting is such sweet sorrow",
    "Cowards die many times before their deaths",
    "Some are born great, some achieve greatness",
    "The course of true love never did run smooth",
    "All that glitters is not gold",
    "Love looks not with the eyes, but with the mind",
    "Fair is foul, and foul is fair",
    "The better part of valor is discretion",
    "This above all: to thine own self be true",
    "The fault, dear Brutus, is not in our stars, but in ourselves",
    "How sharper than a serpent's tooth it is to have a thankless child",
    "There are more things in heaven and earth, Horatio, than are dreamt of in your philosophy",
    "What's done cannot be undone"
]

# Character-level vocabulary: the distinct characters, in sorted order, get
# ids 2..N+1; ids 0 and 1 are reserved for <pad> and <unk>; <eos> takes the
# next free id.
chars = sorted(set(''.join(texts)))
vocab = dict(zip(chars, range(2, len(chars) + 2)))
vocab.update({'<pad>': 0, '<unk>': 1})
vocab['<eos>'] = len(vocab)

# Slice the corpus into 64-character windows and serve them in shuffled batches.
SEQ_LENGTH = 64
BATCH_SIZE = 32
dataset = SimpleTextDataset(texts, vocab, seq_length=SEQ_LENGTH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

## 4. Training Loop

In [6]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 1e-4

model = SimpleLM(vocab_size=len(vocab)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Targets equal to the <pad> id are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])

def train_epoch(model, dataloader):
    """Run one optimisation pass of next-token prediction over ``dataloader``.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``device``
    defined in the setup cell.

    Args:
        model: language model mapping (batch, seq_len) token ids to
            (batch, seq_len, vocab_size) logits.
        dataloader: iterable of (batch, seq_len) token-id tensors.

    Returns:
        Mean training loss per batch (float).

    Raises:
        ValueError: if ``dataloader`` yields no batches (e.g. every text was
            shorter than the dataset's window length).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        batch = batch.to(device)
        optimizer.zero_grad()

        # Shift by one: logits at position t are scored against token t+1.
        inputs = batch[:, :-1]
        targets = batch[:, 1:]

        logits = model(inputs)
        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Counting batches (instead of dividing by len(dataloader)) turns an empty
    # loader into a clear error rather than a ZeroDivisionError.
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; is the dataset empty?")
    return total_loss / num_batches

# Short training run; the mean batch loss is reported after each epoch.
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
    loss = train_epoch(model, dataloader)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

Training: 100%|██████████| 1/1 [00:01<00:00,  1.37s/it]


Epoch 1, Loss: 4.0091


Training: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]


Epoch 2, Loss: 3.0581


Training: 100%|██████████| 1/1 [00:00<00:00,  1.21it/s]


Epoch 3, Loss: 2.6040


Training: 100%|██████████| 1/1 [00:01<00:00,  1.06s/it]


Epoch 4, Loss: 2.3303


Training: 100%|██████████| 1/1 [00:01<00:00,  1.37s/it]


Epoch 5, Loss: 2.1954


Training: 100%|██████████| 1/1 [00:01<00:00,  1.10s/it]


Epoch 6, Loss: 2.0847


Training: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]


Epoch 7, Loss: 1.9640


Training: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s]


Epoch 8, Loss: 1.8717


Training: 100%|██████████| 1/1 [00:00<00:00,  1.20it/s]


Epoch 9, Loss: 1.8229


Training: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]

Epoch 10, Loss: 1.7630





## 5. Beam Search Implementation
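This section shows how the incremental state is used during decoding. The prompt is encoded once. Each later step feeds only the newest token of every beam, and the state is reordered to follow the beams that survive.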

In [10]:
class ParallelBeamSearchDecoder:
    """Beam search that scores all live beams with one batched forward per step.

    The model is driven fairseq-style. The first step feeds the whole prompt;
    later steps feed only each beam's newest token. A shared
    ``incremental_state`` is passed to every call and reordered after each step
    so it follows the surviving beams.

    The model must:
      * be callable as ``model(tokens, incremental_state=..., positions=...)``
        and return logits of shape (num_beams, seq_len, vocab_size);
      * expose ``model.head.out_features`` (stored as ``self.vocab_size``);
      * provide ``reorder_incremental_state(incremental_state, new_order)``.

    EOS is not handled: every beam is extended by exactly ``max_len`` tokens.
    """

    def __init__(self, model, beam_size=5, max_len=20):
        self.model = model
        self.beam_size = beam_size
        self.max_len = max_len
        self.vocab_size = model.head.out_features

    def _init_incremental_state(self, batch_size=1):
        """Return an empty incremental state (a ``defaultdict(dict)``, as in Fairseq).

        ``batch_size`` is accepted for API compatibility but is not used.
        """
        return defaultdict(dict)

    @staticmethod
    def _step_inputs(step, tokens, device):
        """Build the (token ids, positions) tensors fed to the model at ``step``.

        Step 0 primes the state with the full prompt, so both tensors have
        shape (1, prompt_len). Later steps feed only each beam's newest token,
        so both have shape (num_beams, 1). That token's position is its index
        in the beam's sequence. Positions are float tensors.
        """
        if step == 0:
            input_tensor = torch.tensor(tokens, device=device)
            positions = (
                torch.arange(0, input_tensor.size(1), device=device)
                .float()
                .view(1, -1)
                .repeat(len(tokens), 1)
            )
        else:
            input_tensor = torch.tensor([[seq[-1]] for seq in tokens], device=device)
            positions = torch.tensor(
                [[len(seq) - 1] for seq in tokens],
                device=device,
                dtype=torch.float,
            )
        return input_tensor, positions

    def decode(self, initial_input):
        """Continue a prompt with beam search and return the best sequence.

        Args:
            initial_input: token-id tensor of shape (batch, prompt_len). Only
                the first row is used as the prompt.

        Returns:
            list[int]: the prompt followed by ``max_len`` generated tokens of
            the beam with the highest length-normalised log-probability.

        Raises:
            ValueError: if tokens must be generated but the prompt is empty.
        """
        if self.max_len > 0 and initial_input.size(1) == 0:
            raise ValueError("decode() needs a non-empty prompt")

        device = initial_input.device
        tokens = [initial_input[0].tolist()]    # one token list per live beam
        scores = torch.zeros(1, device=device)  # cumulative log-prob per beam
        incremental_state = self._init_incremental_state()

        for step in range(self.max_len):
            input_tensor, positions = self._step_inputs(step, tokens, device)

            with torch.no_grad():
                logits = self.model(
                    input_tensor,
                    incremental_state=incremental_state,
                    positions=positions,
                )

            # Only the last position predicts the next token.
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)  # (num_beams, V)
            candidate_scores = scores.unsqueeze(1) + log_probs    # (num_beams, V)
            vocab = log_probs.size(-1)

            # Never request more candidates than exist (e.g. on the first step
            # with a vocabulary smaller than the beam).
            k = min(self.beam_size, candidate_scores.numel())
            top_scores, top_indices = candidate_scores.reshape(-1).topk(k)

            # Split flat indices back into (source beam, next token).
            beam_indices = torch.div(top_indices, vocab, rounding_mode="floor")
            token_indices = top_indices % vocab

            tokens = [
                tokens[beam] + [tok]
                for beam, tok in zip(beam_indices.tolist(), token_indices.tolist())
            ]
            scores = top_scores

            # Make the cached state follow the selected parent beams. On step 0
            # this copies the single prompt state into every new beam.
            self.model.reorder_incremental_state(incremental_state, beam_indices)

        lengths = torch.tensor([len(seq) for seq in tokens], device=scores.device)
        best_idx = torch.argmax(scores / lengths).item()
        return tokens[best_idx]

# Inference mode for the model (disables dropout), then build the decoder.
model.eval()
parallel_decoder = ParallelBeamSearchDecoder(model, beam_size=5, max_len=20)

# Lookup tables between characters and token ids.
vocab_size = len(vocab)
inv_vocab = {token_id: char for char, token_id in vocab.items()}

# Encode the prompt one character at a time; unknown characters map to <unk>.
prompt = "The quick brown"
initial_tokens = [vocab.get(ch, vocab['<unk>']) for ch in prompt]
test_input = torch.tensor([initial_tokens], device=device)

# Beam-search continuation, mapped back to text (prompt included).
decoded_seq = parallel_decoder.decode(test_input)
test_out = ''.join(inv_vocab.get(tok, '<unk>') for tok in decoded_seq)
print("Decoded sequence:", test_out)

Decoded sequence: The quick brownd and and and thand 
