<a href="https://colab.research.google.com/github/AnanthSankaralingam/kv-cache-research/blob/main/HyperAttention_for_Longbench.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## imports and setup

In [3]:
# install necessary packages.
# Remove any existing transformers package first so that the install in the next cell
# is the only copy on the path.
!pip uninstall transformers -y  # remove the currently installed transformers package
# triton 2.0.0 is the version named in the hyper-attn README; the author noted that the
# README's .dev build could not be found, so the release build is used instead.
# --no-deps presumably keeps pip from replacing the runtime's torch — TODO confirm.
!pip install triton==2.0.0 --no-deps

[0mCollecting triton==2.0.0
  Downloading triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.0 kB)
Downloading triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (63.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.3/63.3 MB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-2.0.0


In [4]:
!pip install --upgrade transformers einops -q  # install the latest versions of transformers, einops, and triton

In [5]:
# clone the hyper-attention repository to access additional required files
!git clone https://github.com/insuhan/hyper-attn.git

# change directory to the cloned repository
%cd hyper-attn
!ls

Cloning into 'hyper-attn'...
remote: Enumerating objects: 33, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 33 (delta 10), reused 7 (delta 7), pack-reused 12 (from 1)[K
Receiving objects: 100% (33/33), 108.29 KiB | 21.66 MiB/s, done.
Resolving deltas: 100% (11/11), done.
/content/hyper-attn/hyper-attn
assets	benchmark_patch_llm.py	benchmark_single_attention.py  LICENSE	models	README.md


In [6]:
# imports
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm  # progress bar for loops
from torch.nn import CrossEntropyLoss  # loss function for classification tasks

import transformers
print("Transformers version:", transformers.__version__)

# import model and tokenizer classes from transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding

# Hugging Face access token used to download meta-llama/Llama-3.1-8B-Instruct.
# Never hardcode it in the notebook: it ends up in version control and in every shared copy.
# Export HUGGING_FACE_HUB_TOKEN before starting the kernel, or type it at the prompt below.
# NOTE: a token was previously committed here in plain text; it should be revoked on
# huggingface.co and replaced.
if not os.environ.get('HUGGING_FACE_HUB_TOKEN'):
    from getpass import getpass
    os.environ['HUGGING_FACE_HUB_TOKEN'] = getpass('Hugging Face access token: ')

# add the hyper-attn checkout to the python path; entries already present are skipped so
# re-running this cell does not keep growing sys.path
for extra_path in ('/content/hyper-attn', '/content/hyper-attn/models'):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

Transformers version: 4.46.3


## Proof of concept: calling HyperAttention directly on random tensors

In [16]:
from models.attention.hyper_attn import HyperAttention

torch.manual_seed(0)  # reproducible dummy inputs

# dummy dimensions
batch_size = 2
seq_len = 4096
input_dim = 64
n_heads = 8  # num attention heads

# same HyperAttention settings LlamaHyperAttention uses by default (including cuda=False)
attn = HyperAttention(
    input_dim=input_dim,
    lsh_num_projs=7,
    block_size=256,
    sample_size=256,
    min_seq_len=4096,
    cuda=False)

# dummy tensors for q k v, laid out (batch, heads, seq_len, head_dim)
query = torch.rand(batch_size, n_heads, seq_len, input_dim)
key = torch.rand(batch_size, n_heads, seq_len, input_dim)
value = torch.rand(batch_size, n_heads, seq_len, input_dim)

# Forward pass. `causal` is passed by keyword: LlamaHyperAttention calls this module with
# scale=..., causal=..., return_lse=..., so a bare positional 4th argument may be bound to
# `scale` instead of `causal` (presumably forward(query, key, value, scale=None, causal=False,
# return_lse=False) — confirm in hyper_attn.py). With return_lse left at its default the call
# returns just the output tensor.
attn_output = attn(query, key, value, causal=True)

print("Query shape:", query.shape)
print("Key shape:", key.shape)
print("Value shape:", value.shape)
print("Attention output shape:", attn_output.shape)
assert attn_output.shape == query.shape

Query shape: torch.Size([2, 8, 4096, 64])
Key shape: torch.Size([2, 8, 4096, 64])
Value shape: torch.Size([2, 8, 4096, 64])
Attention output shape: torch.Size([2, 8, 4096, 64])


## Llama attention layer backed by HyperAttention (no FlashAttention dependency)

In [None]:
import math

def exact_attention(q, k, v, softmax_scale, causal=False, bias=None):
    """Dense softmax attention in plain PyTorch (no FlashAttention kernel).

    Args:
        q: query tensor; assumed (batch, heads, n_query, dim) — TODO confirm against callers.
        k: key tensor with the same leading dims as ``q`` and ``n_key`` rows.
        v: value tensor with ``n_key`` rows.
        softmax_scale: multiplier applied to ``q @ k^T`` before the softmax.
        causal: if True, query row ``i`` may only attend to key columns ``j <= i``
            (top-left aligned lower-triangular mask).
        bias: optional additive bias, broadcastable to the score matrix.

    Returns:
        ``(output, lse)``: the attention output (in ``v``'s dtype) and the float32
        log-sum-exp of the masked, scaled scores with shape ``(..., n_query, 1)``. The lse
        lets callers merge partial attention results; it was previously returned as ``None``.
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    if bias is not None:
        scores = scores + bias
    # softmax / log-sum-exp in float32: exp() in half precision overflows/underflows easily
    scores = scores.float()
    if causal:
        n_query, n_key = scores.shape[-2], scores.shape[-1]
        # True strictly above the diagonal = a future key that must not be attended to
        future = torch.ones(n_query, n_key, dtype=torch.bool, device=scores.device).triu(1)
        scores = scores.masked_fill(future, float('-inf'))
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    probs = torch.softmax(scores, dim=-1).to(v.dtype)
    # weight the values with the attention probabilities
    output = torch.matmul(probs, v)
    return output, lse

# define utility functions for rotary embeddings
def rotate_half(x):
    """Rotate the two halves of the last dimension: ``[x1, x2] -> [-x2, x1]``.

    The split point is ``x.shape[-1] // 2``, so for an odd size the second half is one
    element longer.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat([second.neg(), first], dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors.

    Each input ``t`` becomes ``t * cos + rotate_half(t) * sin``. ``cos`` and ``sin`` must
    broadcast against ``q`` and ``k``; no reshaping is done here.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)

# define the custom Llama attention class backed by HyperAttention
class LlamaHyperAttention(nn.Module):
    """Llama self-attention layer whose core attention is computed by ``HyperAttention``.

    Intended as a drop-in for ``transformers``' ``LlamaAttention`` as of 4.46.x (the
    version this notebook prints). The q/k/v/o projections are shaped like the original
    layer's, so ``patch_attention_layers`` can reuse its weights. Grouped-query attention
    (fewer key/value heads than query heads) is handled by repeating each key/value head
    ``num_attention_heads // num_key_value_heads`` times before calling HyperAttention.

    HyperAttention hyper-parameters (``lsh_num_projs``, ``block_size``, ``sample_size``,
    ``min_seq_len``) are read from ``config`` when present, otherwise 7 / 256 / 256 / 4096.

    Limitations:
        * ``attention_mask`` is ignored; attention is always causal over the whole input,
          so padded batches are not handled (TODO).
        * KV caching is unsupported: passing a cache raises ``NotImplementedError``.
        * Attention weights are never returned (the second output is always ``None``).
    """

    def __init__(self, config):
        super().__init__()

        self.hidden_size = config.hidden_size  # total hidden size
        self.num_heads = config.num_attention_heads  # number of query heads
        self.head_dim = self.hidden_size // self.num_heads  # dimension per head
        # key/value heads; falls back to one per query head when the config has none
        self.num_key_value_heads = getattr(config, 'num_key_value_heads', None) or self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.hidden_size} and "
                f"`num_attention_heads`: {self.num_heads})."
            )
        if (self.num_key_value_groups * self.num_key_value_heads) != self.num_heads:
            raise ValueError(
                f"num_attention_heads must be divisible by num_key_value_heads (got `num_attention_heads`: "
                f"{self.num_heads} and `num_key_value_heads`: {self.num_key_value_heads})."
            )

        # linear projections; k/v only produce num_key_value_heads heads
        kv_dim = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # rotary positional embeddings, built from the config so any rope_scaling entry in
        # it is honoured (transformers 4.46 has no `rope_scaling=` keyword here)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        # initialize hyper attention
        self.hyper_attn = HyperAttention(
            input_dim=self.head_dim,
            lsh_num_projs=getattr(config, 'lsh_num_projs', 7),
            block_size=getattr(config, 'block_size', 256),
            sample_size=getattr(config, 'sample_size', 256),
            min_seq_len=getattr(config, 'min_seq_len', 4096),
            cuda=False,
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        **kwargs
    ):
        """Run causal HyperAttention over ``hidden_states``.

        Args:
            hidden_states: (batch, seq_len, hidden_size) input.
            attention_mask: accepted for API compatibility and ignored.
            position_ids: optional (batch, seq_len) positions, used only when no
                ``position_embeddings`` are passed; defaults to ``0 .. seq_len-1``.
            past_key_value: must be ``None``; KV caching is not supported.
            output_attentions: accepted for compatibility; weights are never produced.
            use_cache: accepted for compatibility; no cache is produced.
            **kwargs: ``position_embeddings=(cos, sin)`` (each assumed
                (batch, seq_len, head_dim), as transformers 4.46's LlamaModel passes them)
                is used when present; other keys such as ``cache_position`` are ignored.

        Returns:
            ``(attn_output, None, past_key_value)`` — the (output, weights, cache) triple
            that transformers 4.46's decoder layer unpacks.

        Raises:
            NotImplementedError: if a KV cache is passed in.
        """
        if past_key_value is not None:
            raise NotImplementedError(
                "LlamaHyperAttention does not support KV caching; call the model with use_cache=False"
            )

        bsz, seq_len, _ = hidden_states.size()

        # ensure consistent data types
        hidden_states = hidden_states.to(dtype=self.q_proj.weight.dtype)

        # project and split into heads: (batch, heads, seq_len, head_dim)
        query_states = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(
            bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(
            bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # rotary embeddings: prefer the ones the model computed once for all layers
        position_embeddings = kwargs.get('position_embeddings')
        if position_embeddings is None:
            if position_ids is None:
                position_ids = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
            position_embeddings = self.rotary_emb(value_states, position_ids)
        cos, sin = position_embeddings
        # add a head axis so cos/sin broadcast over all heads
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos.unsqueeze(1), sin.unsqueeze(1))

        # grouped-query attention: give every query head its own copy of its key/value head
        if self.num_key_value_groups > 1:
            key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
            value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        # use hyper attention; with return_lse=False it returns a bare tensor (see the POC
        # cell), so do NOT tuple-unpack it — that would split the batch dimension
        attn_output = self.hyper_attn(
            query_states,
            key_states,
            value_states,
            scale=None,
            causal=True,
            return_lse=False
        )
        if isinstance(attn_output, tuple):
            attn_output = attn_output[0]

        # merge heads and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

## Replace the attention modules in Llama 3.1

In [21]:
# function to replace modules in the model
def replace_module(model, module_name, new_module):
    """Swap the submodule at dotted path ``module_name`` (e.g. ``"model.layers.0.self_attn"``).

    Every component but the last is resolved with ``getattr`` starting from ``model``; the
    last one is then rebound to ``new_module`` with ``setattr`` on that parent.
    """
    *parent_path, leaf_name = module_name.split('.')
    owner = model
    for part in parent_path:
        owner = getattr(owner, part)
    setattr(owner, leaf_name, new_module)

# function to patch attention layers in the model
def patch_attention_layers(model, **kwargs):
    """Replace every ``LlamaAttention`` in ``model`` with a weight-sharing ``LlamaHyperAttention``.

    Keyword arguments named ``lsh_num_projs``, ``block_size``, ``sample_size`` or
    ``min_seq_len`` are set on a private copy of ``model.config`` that is handed to the new
    modules (``LlamaHyperAttention`` reads them with ``getattr(config, ...)``); previously
    they were silently dropped. Any other keyword arguments (e.g. ``num_patch_layers``,
    ``patch_config``) are accepted but ignored: every attention layer is patched.

    ``model`` is modified in place; nothing is returned.
    """
    import copy

    hyper_config = copy.deepcopy(model.config)  # do not mutate the model's own config
    for key in ('lsh_num_projs', 'block_size', 'sample_size', 'min_seq_len'):
        if key in kwargs:
            setattr(hyper_config, key, kwargs[key])

    # snapshot the targets first: the loop below swaps modules out of the tree being walked
    targets = [(name, module) for name, module in model.named_modules() if isinstance(module, LlamaAttention)]
    for name, old_module in targets:
        new_module = LlamaHyperAttention(hyper_config)
        # share the original projection weights (same tensors, same device and dtype)
        for proj_name in ('q_proj', 'k_proj', 'v_proj', 'o_proj'):
            getattr(new_module, proj_name).weight = getattr(old_module, proj_name).weight
        # move the remaining buffers/parameters (e.g. rotary frequencies) next to the
        # weights; dtype is left alone so float32 buffers keep their precision
        weight_device = old_module.q_proj.weight.device
        if weight_device.type != 'meta':
            new_module.to(weight_device)
        # replace the attention module in the model
        replace_module(model, name, new_module)

# function to get the model and tokenizer
def get_model_and_tokenizer(model_name):
    """Load the Hugging Face causal LM and tokenizer for a supported short model name.

    Args:
        model_name: short identifier; only ``"llama-3.1-8b-instruct"`` is supported.

    Returns:
        ``(model, tokenizer)``. The model is loaded in float16 with ``device_map='auto'``.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    hub_repo_ids = {"llama-3.1-8b-instruct": "meta-llama/Llama-3.1-8B-Instruct"}
    if model_name not in hub_repo_ids:
        raise NotImplementedError("Currently we only support llama-3.1-8b-instruct")
    repo_id = hub_repo_ids[model_name]

    # `token=` replaces the deprecated `use_auth_token=`. .get() avoids a KeyError when the
    # variable is unset; None lets huggingface_hub fall back to a cached login.
    # trust_remote_code is not needed: Llama is implemented natively in transformers.
    hf_token = os.environ.get('HUGGING_FACE_HUB_TOKEN')
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        device_map='auto',
        token=hf_token,
    )
    return model, tokenizer

In [1]:
def main():
    """Evaluate perplexity of Llama-3.1-8B-Instruct with its attention swapped for HyperAttention.

    Builds a synthetic repeated-sentence context (replace with a real dataset as needed),
    keeps only encodings that fill ``seq_len`` tokens after truncation, patches every
    attention layer unless ``attn_method == 'flash'``, and prints the mean per-sequence
    perplexity together with the number of sequences whose perplexity is NaN.

    Relies on ``get_model_and_tokenizer`` and ``patch_attention_layers`` from earlier cells,
    so those cells must be executed first.
    """
    from types import SimpleNamespace

    # run configuration. SimpleNamespace stores the fields on the instance, so vars(args)
    # (used for the printout and to forward settings to patch_attention_layers) sees them;
    # a class with class-level attributes makes vars() return {}.
    args = SimpleNamespace(
        seq_len=1024,  # adjust based on your GPU memory
        patch_config='last',
        attn_method='hyper',
        num_patch_layers=-1,
        block_size=256,
        sample_size=256,
        lsh_num_projs=7,
        min_seq_len=4096,
        model_name='llama-3.1-8b-instruct',
    )
    for arg_name, arg_var in vars(args).items():
        print(f"{arg_name:<16} : {arg_var}")

    # load the model and tokenizer
    model, tokenizer = get_model_and_tokenizer(args.model_name)
    tokenizer.model_max_length = args.seq_len
    dtype = torch.float16  # dtype the model is loaded in; only reported in the summary line

    # prepare dummy data for testing (replace with your dataset if needed)
    data = [{"context": "This is a test sentence. " * (args.seq_len // 5)}]
    encoded_texts = []
    for data_i in data:
        encoded_text = tokenizer(
            data_i['context'],
            return_tensors='pt',
            truncation=True,
            max_length=args.seq_len
        )['input_ids']
        # keep only samples that fill the whole evaluation window
        if encoded_text.size(1) >= args.seq_len:
            encoded_texts.append(encoded_text)

    print(f"# of data longer than {args.seq_len}: {len(encoded_texts)}")
    if not encoded_texts:
        print("No samples reach seq_len tokens; nothing to evaluate.")
        return

    # patch the attention layers with LlamaHyperAttention
    if args.attn_method != 'flash':
        patch_attention_layers(model=model, **vars(args))

    model.eval()
    # with device_map='auto' the embedding layer may not be on "cuda"; send inputs to it
    input_device = model.device
    loss_fct = CrossEntropyLoss(reduction="none")

    ppls = []
    pbar = tqdm(range(len(encoded_texts)))
    for bid in pbar:
        encoded_batch = encoded_texts[bid].to(input_device)
        attn_mask = torch.ones_like(encoded_batch)

        with torch.no_grad():
            out_logits = model(input_ids=encoded_batch, attention_mask=attn_mask, use_cache=False).logits

        # position t predicts token t+1; compute the loss in float32 on the logits' device
        logits_device = out_logits.device
        shift_logits = out_logits[..., :-1, :].float()
        shift_labels = encoded_batch[..., 1:].to(logits_device)
        shift_mask = attn_mask[..., 1:].to(logits_device, dtype=torch.float32)

        token_loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
        token_loss = token_loss.view(shift_labels.size())
        # CrossEntropyLoss uses the natural log, so perplexity is exp (not exp2) of the mean loss
        perplexity_batch = torch.exp((token_loss * shift_mask).sum(1) / shift_mask.sum(1))
        ppls += perplexity_batch.tolist()

        finite_ppls = [p for p in ppls if not np.isnan(p)]
        avg_ppl = np.mean(finite_ppls) if finite_ppls else float('nan')
        pbar.set_description(f"[{bid + 1}/{len(encoded_texts)}] avg_ppl: {avg_ppl:.4f}")

        # clear variables to free memory
        del out_logits, encoded_batch, attn_mask, shift_logits, shift_labels, shift_mask, token_loss, perplexity_batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    ppl_array = np.array(ppls)
    nan_mask = np.isnan(ppl_array)
    nan_cnt = int(nan_mask.sum())
    ppl_mean = ppl_array[~nan_mask].mean() if (~nan_mask).any() else float('nan')

    print(f"Perplexity: {ppl_mean}, NaN count: {nan_cnt}")
    res_str = f"Model: {args.model_name}, dtype: {dtype}, seq_len: {args.seq_len}, num_patch_layers: {args.num_patch_layers}, n_data: {len(encoded_texts)}, Perplexity: {ppl_mean}, NaN count: {nan_cnt}\n"
    print(res_str)

# run the main function
main()

NameError: name 'get_model_and_tokenizer' is not defined