## Imports

In [73]:
import torch
from torch import nn
import networkx as nx
import matplotlib.pyplot as plt

## Medusa Heads

In [74]:
class BaseModelConfig:
    """
    Minimal settings container for the base language model.

    Args:
        hidden_size (int): Stored as ``self.hidden_size``. Defaults to 4096.
        vocab_size (int): Stored as ``self.vocab_size``. Defaults to 32000.

    Attributes:
        tokenizer: Always initialised to ``None`` by the constructor.
    """

    def __init__(self, hidden_size=4096, vocab_size=32000):
        self.hidden_size, self.vocab_size = hidden_size, vocab_size
        # No tokenizer is attached at construction time.
        self.tokenizer = None


class MedusaConfig(BaseModelConfig):
    """
    Base-model configuration extended with the Medusa head hyper-parameters.

    Args:
        num_medusa_heads (int): Number of Medusa heads. Defaults to 5.
        num_medusa_layers (int): Number of layers inside each Medusa head.
            Defaults to 1.
        **kwargs: Forwarded unchanged to ``BaseModelConfig.__init__``.
    """

    def __init__(self, num_medusa_heads=5, num_medusa_layers=1, **kwargs):
        # Let the parent populate its own fields first.
        super().__init__(**kwargs)
        self.num_medusa_heads, self.num_medusa_layers = (
            num_medusa_heads,
            num_medusa_layers,
        )
        

class ResBlock(nn.Module):
    """
    Residual block computing ``x + SiLU(Linear(x))``.

    The linear layer maps ``hidden_size`` features to ``hidden_size`` features.
    Its weight matrix is zero-filled at construction; the bias keeps the
    default ``nn.Linear`` initialisation.

    Args:
        hidden_size (int): Input and output feature size of the linear layer.
    """

    def __init__(self, hidden_size):
        super().__init__()
        projection = nn.Linear(hidden_size, hidden_size)
        # Start from an all-zero weight matrix.
        nn.init.zeros_(projection.weight)
        self.linear = projection
        # SiLU is used to stay consistent with the Llama model.
        self.act = nn.SiLU()

    def forward(self, x):
        """
        Apply the residual transformation.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is ``hidden_size``.

        Returns:
            torch.Tensor: ``x`` plus the activated linear projection of ``x``.
        """
        activated = self.act(self.linear(x))
        return x + activated
    
class Medusa(nn.Module):
    """
    A bank of Medusa heads applied to the same hidden states.

    Each head is an ``nn.Sequential`` of ``num_medusa_layers`` residual blocks
    followed by a bias-free linear projection from ``hidden_size`` to
    ``vocab_size``.

    Args:
        config: Object exposing ``vocab_size``, ``hidden_size``,
            ``num_medusa_heads`` and ``num_medusa_layers`` attributes
            (e.g. a ``MedusaConfig``).
    """

    def __init__(self, config):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.num_medusa_heads = config.num_medusa_heads
        self.num_medusa_layers = config.num_medusa_layers

        self.medusa_heads = nn.ModuleList(
            [self._build_head() for _ in range(self.num_medusa_heads)]
        )

    def _build_head(self):
        """Build one head: residual blocks followed by the vocabulary projection."""
        # A fresh ResBlock is created per layer. Repeating a one-element list
        # (``[ResBlock(...)] * n``) would reuse a single module instance, so
        # all layers of the head would share the same parameters.
        res_blocks = [
            ResBlock(self.hidden_size) for _ in range(self.num_medusa_layers)
        ]
        return nn.Sequential(
            *res_blocks,
            nn.Linear(self.hidden_size, self.vocab_size, bias=False),
        )

    def forward(self, x):
        """
        Run every Medusa head on ``x``.

        Args:
            x (torch.Tensor): Hidden states whose last dimension is ``hidden_size``.

        Returns:
            torch.Tensor: Per-head outputs stacked along a new leading
            dimension, i.e. shape ``(num_medusa_heads, *x.shape[:-1], vocab_size)``.
        """
        # Forward pass through base model then through each head
        # ignoring the base model for now
        return torch.stack([head(x) for head in self.medusa_heads])

In [None]:
# Example usage --> feeding the last predicted token from base model
# Stand-in for the output from the llama model: B X pred_seq_len X hidden_size
hidden_size = 4096
llama_output = torch.randn(1, 10, hidden_size)

# Create a Medusa model
config = MedusaConfig(hidden_size=hidden_size)
medusa = Medusa(config)
# Initialize Weights
# huggingface weights --> FasterDecoding/medusa-vicuna-7b-v1.3/medusa_lm_head.pt
# map_location='cpu' keeps the load working on machines without a GPU even if
# the checkpoint tensors were saved from CUDA; the model above lives on the CPU.
medusa_head_state_dict = torch.load('localpath../medusa_lm_head.pt', map_location='cpu')
medusa.medusa_heads.load_state_dict(medusa_head_state_dict)

# Output from Medusa model num_medusa_heads X B X pred_seq_len X vocab_size
output = medusa(llama_output)
output.shape

## Medusa Buffer & Decoding

In [60]:
# Candidate tree used by MedusaBuffer by default. Each inner list is one node of
# the tree, written as the path from the root: the k-th entry is the 0-based
# rank picked among the top-k predictions of the k-th Medusa head. The list
# holds 63 paths with a maximum depth of 4 and a maximum rank of 9.
mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]
# Number of top-ranked tokens taken from each Medusa head; also the stride
# between consecutive heads when tree nodes are mapped to flat candidate slots.
TOPK = 10

class MedusaBuffer:
    """
    Precomputed index buffers describing a Medusa candidate tree.

    The tree is given as a list of paths (see ``medusa_choice``). Node 0 is the
    root; the remaining nodes are the paths sorted by ``(depth, path)``. On
    construction the following tensors are built:

    * ``medusa_attn_mask``: ``(N, N)`` float mask in which node ``i`` has a 1 for
      itself, for the root and for every ancestor on its path.
    * ``medusa_position_ids``: ``(N,)`` int64 depth of every node (root is 0).
    * ``retrieve_indices``: ``(num_leaves, max_depth + 1)`` int64 node indices of
      each root-to-leaf path, padded with ``-1``.
    * ``tree_indices``: ``(N,)`` int64 slot of every node in a flat candidate
      vector laid out as ``[base token, head-0 top-k, head-1 top-k, ...]``.

    Args:
        medusa_choice (list[list[int]], optional): Tree paths; entry ``k`` of a
            path is the rank chosen among the top-k tokens of head ``k``.
            Defaults to the module-level ``mc_sim_7b_63``.
        topk (int, optional): Number of candidates per head, used as the stride
            between heads in ``tree_indices``. Defaults to the module-level
            ``TOPK``.

    Raises:
        ValueError: If ``medusa_choice`` is empty or a path's ancestor is
            missing from ``medusa_choice``.
    """

    def __init__(self, medusa_choice=None, topk=None):
        # Resolve defaults lazily instead of binding a shared list at def time.
        if medusa_choice is None:
            medusa_choice = mc_sim_7b_63
        if len(medusa_choice) == 0:
            raise ValueError("medusa_choice must contain at least one path")
        self.topk = TOPK if topk is None else topk
        # Sorting by (depth, path) puts every ancestor before its descendants.
        self.medusa_choice = sorted(medusa_choice, key=lambda path: (len(path), path))
        self.get_medusa_attn_mask()
        self.get_medusa_position_ids()
        self.get_retrieve_indices()
        self.get_tree_indices()

    @staticmethod
    def pad_path(path, length, pad_value=-2):
        """Return ``path`` right-padded with ``pad_value`` up to ``length`` items."""
        return path + [pad_value] * (length - len(path))

    def plot_attn_mask(self):
        """Show the attention mask as an image with matplotlib."""
        plt.imshow(self.medusa_attn_mask)

    def plot_tree(self):
        """Draw the candidate tree with networkx using the Graphviz ``dot`` layout."""
        plt.figure(figsize=(40, 20))
        G = nx.DiGraph()
        for path in self.medusa_choice:
            for depth in range(len(path)):
                # Depth-1 nodes hang off the synthetic 'root' node.
                parent = 'root' if depth == 0 else tuple(path[:depth])
                child = tuple(path[:depth + 1])
                G.add_edge(parent, child)

        # Use the Graphviz layout for drawing.
        pos = nx.nx_agraph.graphviz_layout(G, prog='dot')
        nx.draw(G, pos, with_labels=True, node_size=500, node_color="skyblue", font_size=10, width=2, edge_color="gray")
        plt.show()

    def to_dict(self, device=None):
        """
        Return the buffers as a dict of tensors.

        Args:
            device: If not ``None``, every tensor is moved with ``.to(device)``.

        Returns:
            dict: Keys ``'medusa_attn_mask'``, ``'medusa_position_ids'``,
            ``'tree_indices'`` and ``'retrieve_indices'``.
        """
        buffer_dict = {
            'medusa_attn_mask': self.medusa_attn_mask,
            'medusa_position_ids': self.medusa_position_ids,
            'tree_indices': self.tree_indices,
            'retrieve_indices': self.retrieve_indices,
        }
        if device is not None:
            buffer_dict = {key: value.to(device) for key, value in buffer_dict.items()}
        return buffer_dict

    def get_medusa_attn_mask(self):
        """Build ``medusa_attn_mask``: each node sees itself, the root and its ancestors."""
        size = len(self.medusa_choice) + 1
        mask = torch.eye(size, size)
        mask[:, 0] = 1

        # Path -> position of its first occurrence, replacing repeated
        # linear ``list.index`` scans.
        position_of = {}
        for pos, path in enumerate(self.medusa_choice):
            position_of.setdefault(tuple(path), pos)

        for row, path in enumerate(self.medusa_choice, start=1):
            for depth in range(1, len(path)):
                ancestor = tuple(path[:depth])
                if ancestor not in position_of:
                    raise ValueError(
                        f"path {list(path)} requires ancestor {list(ancestor)} "
                        "to be present in medusa_choice"
                    )
                mask[row, position_of[ancestor] + 1] = 1
        self.medusa_attn_mask = mask

    def get_medusa_position_ids(self):
        """Build ``medusa_position_ids``: depth per node as int64, root first."""
        depths = [0] + [len(path) for path in self.medusa_choice]
        # Built directly as int64; concatenating a float zero with an integer
        # tensor would promote the position ids to float.
        self.medusa_position_ids = torch.tensor(depths, dtype=torch.long)

    def get_tree_indices(self):
        """Build ``tree_indices``: flat candidate slot of every tree node."""
        # Slot 0 is the base-model token; head ``d`` (0-based depth) occupies
        # slots ``1 + d * topk`` to ``(d + 1) * topk``, indexed by rank.
        slots = [0] + [
            path[-1] + self.topk * (len(path) - 1) + 1 for path in self.medusa_choice
        ]
        self.tree_indices = torch.tensor(slots, dtype=torch.long)

    def get_retrieve_indices(self):
        """Build ``retrieve_indices``: node indices along every root-to-leaf path."""
        node_mask = self.medusa_attn_mask[1:, 1:]
        covered = set()
        leaf_rows = []
        # Walk from the deepest paths backwards: a path that is not a prefix of
        # an already collected one is a leaf.
        for pos in range(len(self.medusa_choice) - 1, -1, -1):
            path = tuple(self.medusa_choice[pos])
            if path in covered:
                continue
            leaf_rows.append(torch.where(node_mask[pos] == 1)[0].tolist())
            covered.update(path[:depth] for depth in range(1, len(path) + 1))

        max_length = max(len(path) for path in self.medusa_choice)
        padded = [MedusaBuffer.pad_path(row, max_length) for row in leaf_rows]
        # Shift by one to account for the root node; padding (-2) becomes -1.
        retrieve = torch.tensor(padded, dtype=torch.long) + 1
        root_column = torch.zeros((retrieve.shape[0], 1), dtype=torch.long)
        self.retrieve_indices = torch.cat([root_column, retrieve], dim=1)

In [71]:
class MedusaDecoding:
    """
    Illustrative Medusa decoding step built on the buffers from ``MedusaBuffer``.

    One step consists of ``generate_candidates`` -> ``tree_decoding`` ->
    ``evalulate_posterior`` -> ``update_inference_inputs``. The base model is
    not wired in yet: ``tree_decoding`` returns random logits as a placeholder.

    Args:
        model: Stored as ``self.model``; not used by the current placeholder code.
        tokenizer: Stored as ``self.tokenizer``; not used by the current code.
    """

    def __init__(self, model=None, tokenizer=None):
        self.model = model
        self.tokenizer = tokenizer
        self.generate_medusa_buffer()

    def generate_medusa_buffer(self, device=None):
        """Build the default Medusa buffers and store them in ``self.medusa_buffer``."""
        self.medusa_buffer = MedusaBuffer().to_dict(device=device)

    def top_p_nucleus_sampling(self, logits, temperature, p):
        """
        Sample one token id per position with temperature and top-p filtering.

        Args:
            logits (torch.Tensor): Scores with the vocabulary on the last dimension.
            temperature (float): Softmax temperature. ``0`` selects the argmax
                (greedy) instead of dividing by zero.
            p (float): Cumulative probability threshold of the nucleus.

        Returns:
            torch.Tensor: Sampled token ids of shape ``logits.shape[:-1] + (1,)``.
        """
        if temperature == 0:
            return torch.argmax(logits, dim=-1, keepdim=True)

        probs = torch.softmax(logits / temperature, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        # Drop a token only if the mass of the tokens ranked before it already
        # exceeds p, so the token that crosses the threshold stays in the
        # nucleus and the kept mass is at least p.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        dropped = mass_before > p
        dropped[..., 0] = False  # always keep the most likely token
        sorted_probs[dropped] = 0.0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

        flat_probs = sorted_probs.reshape(-1, sorted_probs.size(-1))
        picked = torch.multinomial(flat_probs, num_samples=1)
        picked = picked.view(sorted_probs.shape[:-1] + (1,))
        # Map positions in the sorted order back to vocabulary ids.
        return torch.gather(sorted_idx, -1, picked)

    def generate_candidates(
            self,
            medusa_logits,
            logits,
            temperature=1,
            p=1,
    ):
        """
        Build the candidate tokens for the tree from base and Medusa logits.

        Args:
            medusa_logits (torch.Tensor): Indexed as ``[:, 0, -1]``, i.e.
                ``(num_heads, batch, seq_len, vocab)``; only batch item 0 and
                the last position are used.
            logits (torch.Tensor): Indexed as ``[0, -1]``, i.e.
                ``(batch, seq_len, vocab)``; only batch item 0 and the last
                position are used.
            temperature (float): Temperature for sampling the base token.
            p (float): Top-p threshold for sampling the base token.

        Returns:
            tuple: ``(cart_candidates, tree_candidates)`` where
            ``cart_candidates`` holds the tokens along every root-to-leaf path
            (padding slots contain token 0) and ``tree_candidates`` holds one
            token per tree node with a leading batch dimension of 1.
        """
        candidates_logit = self.top_p_nucleus_sampling(logits[0, -1], temperature, p)
        candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices
        # Flat layout: [base token, head-0 top-k, head-1 top-k, ...].
        candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(-1)], dim=-1)

        # Map the combined candidates to the tree indices to get tree candidates.
        tree_candidates = candidates[self.medusa_buffer['tree_indices']]

        # Append a zero so the -1 padding in retrieve_indices selects it.
        tree_candidates_ext = torch.cat(
            [
                tree_candidates,
                torch.zeros((1), dtype=torch.long, device=tree_candidates.device),
            ],
            dim=0,
        )

        # Retrieve the cartesian candidates using the retrieve indices.
        cart_candidates = tree_candidates_ext[self.medusa_buffer['retrieve_indices']]
        # Unsqueeze the tree candidates for dimension consistency.
        tree_candidates = tree_candidates.unsqueeze(0)
        return cart_candidates, tree_candidates

    def tree_decoding(
            self,
            input_ids,
            model = None,
            tree_candidates = None,
            past_key_values = None,
    ):
        """
        Placeholder for the tree forward pass; returns random logits.

        Args:
            input_ids (torch.Tensor): Current token ids; ``shape[1]`` is used as
                the sequence length.
            model: Unused until the real forward pass is wired in.
            tree_candidates: Unused until the real forward pass is wired in.
            past_key_values: Unused until the real forward pass is wired in.

        Returns:
            tuple: ``(medusa_logits, logits, None)`` with the logits gathered
            along every root-to-leaf path of ``retrieve_indices``.
        """
        # Position IDs of the tree nodes, offset by the current sequence length.
        # Only needed by the model call sketched below.
        position_ids = self.medusa_buffer['medusa_position_ids'] + input_ids.shape[1]

        # # Use the model to decode the tree candidates.
        # # The model is expected to return logits for the Medusa structure, original logits, and possibly other outputs.
        # tree_medusa_logits, outputs, tree_logits = model(
        #     tree_candidates,
        #     output_orig=True,
        #     past_key_values=past_key_values,
        #     position_ids=position_ids,
        #     medusa_forward=True,
        # )
        # currently randomly generating for illustration; the node count follows
        # the buffer (64 for the default tree) instead of being hard-coded.
        num_tree_nodes = self.medusa_buffer['tree_indices'].shape[0]
        num_heads, vocab_size = 5, 32000
        tree_logits = torch.randn(1, num_tree_nodes, vocab_size)
        tree_medusa_logits = torch.randn(num_heads, 1, num_tree_nodes, vocab_size)

        # Reorder the obtained logits based on the retrieve_indices to ensure consistency with some reference ordering.
        logits = tree_logits[0, self.medusa_buffer['retrieve_indices']]
        medusa_logits = tree_medusa_logits[:, 0, self.medusa_buffer['retrieve_indices']]
        return medusa_logits, logits, None

    def evalulate_posterior(self, logits, candidates, temperature=1, p=0.8):
        """
        Pick the candidate path with the longest accepted prefix.

        A token is sampled for every path position from ``logits``; a candidate
        token is accepted while it matches the token sampled at the previous
        position on the same path.

        Args:
            logits (torch.Tensor): ``(num_paths, path_len, vocab)`` logits.
            candidates (torch.Tensor): ``(num_paths, path_len)`` candidate tokens.
            temperature (float): Sampling temperature (``0`` means greedy).
            p (float): Top-p threshold.

        Returns:
            tuple: ``(best_candidate, accept_length)`` as 0-dim tensors. When no
            token is accepted, ``best_candidate`` is 0.
        """
        sampled = self.top_p_nucleus_sampling(logits, temperature, p)
        posterior_mask = (
                candidates[:, 1:] == sampled.squeeze(-1)[:, :-1]
        ).int()
        # cumprod zeroes everything after the first mismatch, so the row sum is
        # the length of the accepted prefix.
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max()
        # Choose the best candidate
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length

    # Correctly spelled alias; the original misspelled name is kept for callers.
    evaluate_posterior = evalulate_posterior

    def update_inference_inputs(self, best_candidate, accept_length, input_ids, cart_candidates, logits, medusa_logits, new_token, past_key_values_data=None, current_length_data=None):
        """
        Append the accepted tokens and select the logits for the next step.

        Args:
            best_candidate: Index of the chosen path.
            accept_length: Number of accepted speculative tokens on that path.
            input_ids (torch.Tensor): ``(1, seq_len)`` current token ids.
            cart_candidates (torch.Tensor): ``(num_paths, path_len)`` tokens.
            logits (torch.Tensor): ``(num_paths, path_len, vocab)`` logits.
            medusa_logits (torch.Tensor): ``(num_heads, num_paths, path_len, vocab)``.
            new_token: Running count of generated tokens.
            past_key_values_data: Unused until the KV-cache update is enabled.
            current_length_data: Unused until the KV-cache update is enabled.

        Returns:
            tuple: ``(input_ids, logits, medusa_logits, new_token)`` where the
            logits are those at the last accepted position of the chosen path.
        """
        # Calculate the starting position for new tokens based on the previous input length
        prev_input_len = input_ids.shape[1]
        # Map the best candidate indices to the original indices in the sequence.
        # Only needed by the KV-cache update sketched below.
        select_indices = (
            self.medusa_buffer['retrieve_indices'][best_candidate, : accept_length + 1] + prev_input_len
        )
        # Append the tokens from the best candidate to the input sequence
        input_ids = torch.cat(
            [input_ids, cart_candidates[None, best_candidate, : accept_length + 1]], dim=-1
        )

        # # Update the past key values based on the selected tokens
        # tgt = past_key_values_data[..., select_indices, :]
        # dst = past_key_values_data[..., prev_input_len : prev_input_len + tgt.shape[-2], :]
        # dst.copy_(tgt, non_blocking=True)
        # current_length_data.fill_(prev_input_len + tgt.shape[-2])

        # Extract logits and medusa logits for the accepted tokens
        logits = logits[None, best_candidate, accept_length : accept_length + 1]
        medusa_logits = medusa_logits[
            :, None, best_candidate, accept_length : accept_length + 1
        ]
        # Update the new token counter
        new_token = new_token + accept_length + 1

        return input_ids, logits, medusa_logits, new_token

In [None]:
# Stand-ins for the model outputs. The last dimension of logits is the
# vocabulary size (not the hidden size), matching what tree_decoding produces.
vocab_size = 32000
medusa_logits = torch.randn(5, 1, 10, vocab_size)  # num_medusa_heads X B X seq_len X vocab_size
logits = torch.randn(1, 10, vocab_size)  # B X seq_len X vocab_size
input_ids = torch.randint(0, vocab_size, (1, 10))

# One decoding step: propose candidates, score the tree, accept the best path.
decoder = MedusaDecoding()
cart_candidates, tree_candidates = decoder.generate_candidates(medusa_logits, logits)
medusa_logits, logits, outputs = decoder.tree_decoding(input_ids)
best_candidate, accept_length = decoder.evalulate_posterior(logits, cart_candidates)
input_ids, logits, medusa_logits, new_token = decoder.update_inference_inputs(best_candidate, accept_length, input_ids, cart_candidates, logits, medusa_logits, 0)

print(medusa_logits.shape, logits.shape, input_ids.shape, new_token)