# Multi-Token Prediction (MTP)

## Step 0: Load Packages

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Step 1: Define RMSNorm Class

In [None]:
class RMSNorm(nn.Module):
    """Parameter-free root-mean-square normalization.

    Divides the input by the root mean square of its last dimension, adding
    ``eps`` to the mean square before the square root for numerical
    stability. There is no learnable scale (weight) parameter.

    Args:
        d_model: feature size; accepted for API symmetry but not used.
        eps: small constant added to the mean square before the square root.
    """

    def __init__(self, d_model, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Works for any shape (..., features): only the last axis is reduced,
        # and keepdim=True lets the divisor broadcast back over that axis.
        mean_square = torch.mean(x ** 2, dim=-1, keepdim=True)
        return x / (mean_square + self.eps).sqrt()

## Step 2: Define the Multi-Token Prediction (MTP) class

In [None]:
class SimpleMTP(nn.Module):
    """Minimal sequential Multi-Token Prediction (MTP) module.

    For every start position ``i`` the module runs ``num_heads`` depths in
    sequence. Depth ``k`` (0-indexed) performs these steps:

    1. RMS-normalizes the previous hidden state and the embedding of token
       ``i + k + 1``.
    2. Concatenates the two and projects them back to ``d_model`` with its
       own linear layer.
    3. Runs the result through its own Transformer encoder layer.
    4. Maps it to vocabulary logits with the unembedding, which is
       weight-tied to the embedding.

    The depth's output hidden state is the input hidden state of the next
    depth.
    """

    def __init__(self, d_model: int, vocab_size: int, num_heads: int = 3, nhead: int = 1):
        """
        d_model: hidden size (8 in this example)
        vocab_size: number of token IDs; sizes the shared embedding/unembedding
        num_heads: number of sequential MTP steps (D)
        nhead: attention heads in each Transformer block
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_heads = num_heads

        # Shared modules.
        self.rmsnorm = RMSNorm(d_model)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)
        # Tie weights: both are (vocab_size, d_model) matrices.
        self.unembed.weight = self.embed.weight

        # One projection + one Transformer layer per depth.
        self.projections = nn.ModuleList(
            [nn.Linear(2 * d_model, d_model) for _ in range(num_heads)]
        )
        self.transformers = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead) for _ in range(num_heads)]
        )

    def forward(self, token_ids: torch.LongTensor, init_hidden: torch.Tensor = None):
        """Compute multi-token logits for every valid start position.

        token_ids: (batch, seq_len) integer IDs of your input tokens
        init_hidden: optional (batch, seq_len, d_model) base hidden states;
                     if None, uses the token embeddings as the initial hidden.
                     Only the first T-D positions are read.

        Returns:
          logits_out: Tensor of shape (batch, T-D, D, vocab_size), where
                      T=seq_len and D=num_heads. logits_out[:, i, k] is the
                      output of depth k at start position i. If T <= D there
                      is no valid start position and the result has shape
                      (batch, 0, D, vocab_size).

        NOTE(review): depth k consumes the embedding of token i+k+1. If a
        loss pairs logits_out[:, i, k] with that same token i+k+1, the head
        sees its own label. The usual pairing is token i+k+2 (e.g. labels
        already shifted by one) — confirm against the training loop.
        """
        B, T = token_ids.shape
        D = self.num_heads
        # Start positions i satisfy i + D < T, so there are T - D of them.
        L = max(T - D, 0)

        embeds = self.embed(token_ids)  # (B, T, d_model)
        if L == 0:
            # Too short for a full chain of D future tokens: return an empty
            # result instead of failing inside torch.stack on an empty list.
            return embeds.new_zeros(B, 0, D, self.vocab_size)

        h0_seq = embeds if init_hidden is None else init_hidden
        h_prev = h0_seq[:, :L, :]  # (B, L, d_model): depth-0 hidden at each i

        logits_per_depth = []
        for k in range(D):
            # Embedding of token i + k + 1 for every start position i.
            tok_embed = embeds[:, k + 1 : k + 1 + L, :]  # (B, L, d_model)

            merged = torch.cat([self.rmsnorm(h_prev), self.rmsnorm(tok_embed)], dim=-1)
            proj = self.projections[k](merged)  # (B, L, d_model)

            # The layer expects (S, N, E) (batch_first=False). Each
            # (batch, position) pair is fed as an independent length-1
            # sequence, exactly as the per-position formulation would.
            x = proj.reshape(1, B * L, self.d_model)
            h_curr = self.transformers[k](x).reshape(B, L, self.d_model)

            logits_per_depth.append(self.unembed(h_curr))  # (B, L, V)
            h_prev = h_curr  # chain hidden state into the next depth

        return torch.stack(logits_per_depth, dim=2)  # (B, L, D, V)

## Step 3: Pass input tokens through the model to predict multiple future tokens.

In [None]:
batch_size, seq_len, d_model, vocab_size = 1, 8, 8, 5000
model = SimpleMTP(d_model=d_model, vocab_size=vocab_size, num_heads=3)
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))


# Forward pass
logits = model(tokens)
# Shape is (batch_size, T - D, D, V) == (1, 8 - 3, 3, 5000) == (1, 5, 3, 5000)
assert logits.shape == (batch_size, seq_len - model.num_heads, model.num_heads, vocab_size)
print("Logits shape:", logits.shape)

# To inspect the output of the first depth (k=0) at start position i=0:
print("Head k=0 at i=0 logits:", logits[0, 0, 0])  # a tensor of length vocab_size

# Or to get all depths' predictions at i=0 as token IDs (argmax over the vocabulary):
pred_ids = logits[0, 0].argmax(dim=-1)
print("Predicted tokens at i=0 for all heads:", pred_ids)  # a length-3 tensor

Logits shape: torch.Size([1, 5, 3, 5000])
Head k=0 at i=0 logits: tensor([ 3.6052,  2.8964, -1.7114,  ...,  1.4961, -3.3179,  1.1599],
       grad_fn=<SelectBackward0>)
Predicted tokens at i=0 for all heads: tensor([4207, 4708, 4765])


## Step 4: Calculate the loss between target tokens and predicted tokens

In [None]:
batch_size, seq_len, vocab_size = 1, 8, 5000

# One target ID per input position; the shape must match `tokens` (B, T).
targets = torch.randint(0, vocab_size, (batch_size, seq_len))
print("targets.shape ->", targets.shape)  # torch.Size([1, 8])


# Recompute the logits for the current model and tokens.
logits = model(tokens)  # (B, T-D, D, V) == (1, 5, 3, 5000)
B, L, D, V = logits.shape
_, T = targets.shape  # (1, 8)
assert L == T - D  # 5 == 8 - 3

# logits[:, i, k] is scored against targets[:, i + k + 1]. Build that index
# grid once instead of looping over every (i, k) pair.
# NOTE(review): depth k also consumes the embedding of tokens[:, i + k + 1].
# If `targets` were the unshifted input tokens, each head would see its own
# label; confirm `targets` holds next-token-shifted labels.
target_idx = (
    torch.arange(L, device=targets.device).unsqueeze(1)
    + torch.arange(D, device=targets.device).unsqueeze(0)
    + 1
)  # (L, D)
target_ik = targets[:, target_idx]  # (B, L, D)

# Mean cross-entropy over all B*L*D predictions. This equals averaging the
# per-(i, k) batch means, because every (i, k) pair has the same B samples.
loss = F.cross_entropy(logits.reshape(-1, V), target_ik.reshape(-1))
print("MTP loss:", loss.item())




targets.shape -> torch.Size([1, 8])
MTP loss: 13.472195625305176
