In [2]:
import torch
from llama.model import Transformer, ModelArgs

Hooks don't let me alter the positional embedding properly, because the positional embedding is not a separate module. Instead, I will alter it directly in the implementation, which is why I have to fork llama's repo. Modifying it at instantiation would make the change non-general, and so would modifying it where the rotary embedding is applied. So I will alter the representation directly in the forward pass and make it depend on an argument.

Open question: How do I want to alter the rotary positional embedding?

Then, I will make a new version of Llama from llama.generation called AlteredLlama, which loads llama from a file but creates an AlteredTransformer instead of a Transformer. I will likely need to override all methods in order to pass the right arguments to the forward method of AlteredTransformer. This is really inconvenient.

Do I have a cleaner way of doing it?

I could instead create a mode for the AlteredTransformer that determines whether the contexts must be reset and, if so, at which indices. In this case, I still need to create AlteredLlama, but I only have to override build so that it loads the model as an AlteredTransformer, and define a method on AlteredLlama that switches the mode in the transformer. All generation code remains the same. For additional simplicity, I can create a wrapper method for each generation method that first calls the mode switch according to the prompt, then calls the desired generation method. I like that, I'll do this latter version.
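
To make the wrapper idea concrete, here is a minimal sketch of the pattern, not the real classes. `StubTransformer` and `StubLlama` are hypothetical stand-ins for AlteredTransformer and for the generation class, so the cell runs without model weights; `AlteredLlamaSketch` and `altered_text_completion` are illustrative names, and the echo inside `text_completion` only replaces the actual generation code.

```python
class StubTransformer:
    """Hypothetical stand-in for AlteredTransformer: it only records the mode."""

    def __init__(self):
        self.alteration_mode = None
        self.alteration_kwargs = {}

    def switch_mode(self, mode=None, **kwargs):
        self.alteration_mode = mode
        self.alteration_kwargs = kwargs


class StubLlama:
    """Hypothetical stand-in for the generation class wrapping a model."""

    def __init__(self, model):
        self.model = model

    def text_completion(self, prompts):
        # Placeholder for the unchanged generation code: echo the prompt
        # together with the mode the model was in when it was called.
        return [f"{p} [mode={self.model.alteration_mode}]" for p in prompts]


class AlteredLlamaSketch(StubLlama):
    def switch_mode(self, mode=None, **kwargs):
        # Forward the mode switch to the underlying transformer.
        self.model.switch_mode(mode, **kwargs)

    def altered_text_completion(self, prompts, mode=None, **kwargs):
        # Wrapper: set the mode, run the untouched generation method,
        # then restore the default mode even if generation raises.
        self.switch_mode(mode, **kwargs)
        try:
            return self.text_completion(prompts)
        finally:
            self.switch_mode(None)


llama = AlteredLlamaSketch(StubTransformer())
print(llama.altered_text_completion(["hello"], mode="reset", indices=[3, 7]))
```

Restoring the mode in a `finally` block is a design choice of this sketch: it keeps later plain generation calls from silently reusing an alteration left over from an earlier call.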

Note: An advantage of my implementation is that it can be used on top of the llama library, instead of requiring manipulations inside the library code.

In [None]:
class AlteredTransformer(Transformer):
    def __init__(self, params: ModelArgs):
        super().__init__(params)
        self.alteration_mode = None
        self.alteration_kwargs = dict()

    def switch_mode(self, mode: str = None, **kwargs):
        '''
        Switch the model's alteration mode.
        
        @param mode: Which alteration mode to use. Valid values are {None, "median", "reset"}.
        @param kwargs: Any kwargs needed for the specified alteration mode.

        @returns None
        '''
        assert mode is None or mode in {"median", "reset"}, "Invalid mode provided"
        # Note: I don't think it makes any sense to implement zero-patching to replace rotary positional embedding. But I can always add it later if we want to try it as well.

        error_msg = "Provide {} argument for mode {}."

        # Asserting the validity of additional arguments for each mode
        if mode in {"median", "reset"}:
            indices = kwargs.get("indices", None)
            assert indices is not None, error_msg.format("indices", mode)

            if mode == "reset":
                assert isinstance(indices, list) and (len(indices) > 1), "Indices must be a list of at least two elements."
            if mode == "median":
                assert isinstance(indices, tuple) and (len(indices) == 2), "Indices must be a tuple of two elements."
            
            previous = -1  # lets a first index of 0 pass the ordering check below
            for i in indices:
                assert isinstance(i, int) and i >= 0, "Indices provided must be non-negative integers"
                assert i > previous, "Each index must be greater than the previous one."
                previous = i
        
        self.alteration_mode = mode
        self.alteration_kwargs = kwargs

    def alter_positional_embedding(self, freqs_cis):
        '''
        Alter the positional embedding using the approach specified in self.alteration_mode.

        @param freqs_cis: Frequencies to alter.

        @returns torch.Tensor of the same shape as freqs_cis, which are the new positional embeddings to use.
        '''
        if self.alteration_mode is None:
            return freqs_cis
        if self.alteration_mode == "median":
            raise NotImplementedError  # TODO
        if self.alteration_mode == "reset":
            raise NotImplementedError  # TODO

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):  # Let's overwrite the standard Transformer forward
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices.
            start_pos (int): Starting position for attention caching.

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.

        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        freqs_cis = self.alter_positional_embedding(freqs_cis)  # Modification compared to the forward of the superclass

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (seqlen, seqlen), float("-inf"), device=tokens.device
            )

            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack([
                torch.zeros((seqlen, start_pos), device=tokens.device),
                mask
            ]).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output
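
The "reset" mode is still a TODO above. As one possible reading, and purely as an assumption on my side, resetting could mean that position counting restarts from 0 at each index in `indices`. The sketch below computes only the position ids in plain Python; `reset_position_ids` is a hypothetical helper, not part of llama.

```python
def reset_position_ids(seqlen, indices):
    """Position ids for `seqlen` tokens, restarting from 0 at every index in `indices`.

    Assumption: "reset" means the position counter restarts at each given index.
    """
    boundaries = set(indices)
    position_ids = []
    pos = 0
    for token_idx in range(seqlen):
        if token_idx in boundaries:
            pos = 0  # restart counting at this token
        position_ids.append(pos)
        pos += 1
    return position_ids


print(reset_position_ids(8, [3, 6]))  # [0, 1, 2, 0, 1, 2, 0, 1]
```

Under this reading, the altered embedding would presumably be obtained by gathering the rows of the precomputed frequency table at these ids rather than taking a contiguous slice. With key-value caching (start_pos > 0), the ids would also have to account for the tokens that are already cached, which this sketch ignores.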