In [None]:
!pip install transformers accelerate safetensors bitsandbytes attention_sinks -U

In [2]:
import time
import types
import math
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
from attention_sinks import AutoModel
# The `attention_sink_*` keyword arguments are only understood by the loader
# shipped with `attention_sinks`. It is imported under an alias so that it does
# not shadow the plain transformers `AutoModelForCausalLM` used for the
# baseline model further down in the notebook.
from attention_sinks import AutoModelForCausalLM as SinkAutoModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"
model = SinkAutoModelForCausalLM.from_pretrained(
    model_id,
    # for efficiency:
    load_in_4bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
    # `attention_sinks`-specific arguments:
    attention_sink_size=4,
    attention_sink_window_size=252,
)

model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Our input text
text = "Vaswani et al. (2017) introduced the Transformers"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # A TextStreamer prints tokens as they're being generated
    streamer = TextStreamer(tokenizer)

    generated_tokens = model.generate(
        input_ids,
        generation_config=GenerationConfig(
            # `max_length` is intentionally not set: it conflicts with
            # `max_new_tokens`, which is the limit that matters for this demo.
            use_cache=True,
            min_new_tokens=100_000,
            max_new_tokens=1_000_000,
            penalty_alpha=0.6,
            top_k=5,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
        streamer=streamer,
    )
    # Decode the final generated text
    output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

print(output_text)

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/5.06G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# 1) Original Transformer Model

In [16]:
### load model ###
model_name_or_path = "NousResearch/Llama-2-7b-hf"

# Tokenizer first; the EOS id doubles as the padding id.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Baseline causal LM, loaded in 8-bit and placed automatically across devices.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Alternative (kept for reference): load through `attention_sinks` instead.
# from attention_sinks import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     model_name_or_path,
#     trust_remote_code=True,
#     torch_dtype=torch.float16,
#     device_map="auto",
#     attention_sink_size=4,
#     attention_sink_window_size=512,
# )

Downloading (…)okenizer_config.json:   0%|          | 0.00/749 [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/646 [00:00<?, ?B/s]

Downloading (…)fetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [17]:
@torch.no_grad()
def greedy_generate(model, tokenizer, prompt, kv_cache, max_new_tokens=1000):
    """Greedily decode a completion for a single prompt.

    The prompt is wrapped as ``[INST] {prompt} [/INST]``, tokenized, and fed to
    ``model`` one step at a time, always picking the arg-max token. Decoding
    stops after ``max_new_tokens`` tokens or as soon as the EOS token is
    produced. Only a batch of one sequence is supported, because each predicted
    token is read with ``.item()``.

    Args:
        model: Callable as ``model(input_ids, past_key_values=..., use_cache=True)``
            returning an object with ``logits`` and ``past_key_values``; must
            expose ``device``.
        tokenizer: Callable tokenizer exposing ``eos_token_id`` and ``decode``.
        prompt: Raw user prompt (string).
        kv_cache: Optional eviction policy exposing
            ``evict_for_space(past_key_values, num_coming)``. When given, it is
            applied before every forward pass so the key/value cache stays
            bounded. ``None`` disables eviction.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the generated tokens (special tokens skipped).
    """
    prompt = f"[INST] {prompt} [/INST]"
    # prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    past_key_values = None
    pred_out = []

    for _ in range(max_new_tokens):
        if kv_cache is not None:
            # Make room for the tokens about to be appended. Evicting once
            # before the loop would be a no-op because the cache is still
            # empty at that point, so the policy has to run on every step.
            past_key_values = kv_cache.evict_for_space(past_key_values, input_ids.shape[1])

        # With the cache, the model saves the hidden state once it has been computed,
        # and only computes the one for the most recently generated output token at each time step,
        # re-using the saved ones for hidden tokens
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        next_token = pred_token_idx.item()
        pred_out.append(next_token)
        # Only the freshly generated token is fed in on the next step.
        input_ids = pred_token_idx

        if next_token == tokenizer.eos_token_id:
            break

    generated_text = tokenizer.decode(
        pred_out,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
        spaces_between_special_tokens=False,
    )

    return generated_text

In [None]:
# Time one greedy generation without any KV-cache eviction policy.
prompt = "Write me a python script for transformer decoder architecture from scratch"
start = time.time()
res = greedy_generate(
    model,
    tokenizer,
    prompt,
    kv_cache=None,
    max_new_tokens=1000,
)
print(f"Output: {res}")
print(f"Time Elapsed: {int(time.time() - start)}")

# 2) Attention Sink

In [None]:
from transformers.models.llama.modeling_llama import LlamaAttention
import torch.nn.functional as F

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis.

    Input is (batch, num_key_value_heads, seqlen, head_dim); output is
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With ``n_rep == 1``
    the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold
    # it back into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    """Apply rotary position embedding to a single tensor ``x``.

    ``cos`` and ``sin`` carry two leading singleton dimensions which are
    squeezed away, leaving [seq_len, dim] lookup tables. Rows are gathered with
    ``position_ids`` and a head axis is inserted so the result broadcasts
    against ``x``.
    """
    # Drop the two leading singleton axes -> [seq_len, dim]
    cos_table = cos.squeeze(1).squeeze(0)
    sin_table = sin.squeeze(1).squeeze(0)
    # Select the rows for the requested positions -> [bs, 1, seq_len, dim]
    cos_sel = cos_table[position_ids].unsqueeze(1)
    sin_sel = sin_table[position_ids].unsqueeze(1)
    return (x * cos_sel) + (rotate_half(x) * sin_sel)

def llama_pos_shift_attention_forward(
    self,
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=False,
    **kwargs,
):
    """Attention forward pass that assigns positions relative to the KV cache.

    Meant to be bound as the ``forward`` of a Llama attention module. Unlike
    the stock implementation, keys are stored in the cache *before* rotary
    embedding is applied and are rotated on every call using their index inside
    the cache (``0 .. kv_seq_len - 1``). Queries are rotated with
    ``position_ids``. This keeps positions contiguous after cache entries have
    been evicted.

    Args:
        self: Attention module providing ``q_proj``/``k_proj``/``v_proj``/
            ``o_proj``, ``rotary_emb``, ``config.pretraining_tp`` and the head
            geometry attributes read below.
        hidden_states: Tensor of shape (batch, q_len, hidden).
        attention_mask: Optional additive mask applied to the attention scores.
        position_ids: Positions used to rotate the queries. When ``None``, the
            last ``q_len`` cache positions are used.
        past_key_value: Optional ``(keys, values)`` pair of un-rotated cached
            states; the sequence axis is dimension 2.
        output_attentions: Whether to return the attention weights.
        use_cache: Whether to return the updated ``(keys, values)`` pair.
        **kwargs: Extra keyword arguments passed by the calling decoder layer
            (for example ``padding_mask``); accepted and ignored so the patch
            does not raise ``TypeError`` when such arguments are supplied.

    Returns:
        ``(attn_output, attn_weights_or_None, past_key_value_or_None)``.
    """
    bsz, q_len, _ = hidden_states.size()
    tp = self.config.pretraining_tp

    if tp > 1:
        # Slice the projection weights into `tp` chunks and apply them piecewise.
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // tp
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // tp, dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(tp)]
        query_states = torch.cat(query_states, dim=-1)
        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(tp)]
        key_states = torch.cat(key_states, dim=-1)
        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    # (batch, q_len, heads * head_dim) -> (batch, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    if position_ids is None:
        # Default: the new tokens occupy the last `q_len` slots of the cache.
        position_ids = torch.arange(
            kv_seq_len - q_len, kv_seq_len, device=hidden_states.device
        ).unsqueeze(0)

    ### Shift Pos: query pos is min(cache_size, idx)
    # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    # The cache stores keys that have NOT been rotated yet.
    past_key_value = (key_states, value_states) if use_cache else None

    ### Shift Pos: key pos is the pos in cache
    key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
    key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
    attn_weights = attn_weights.to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if tp > 1:
        attn_output = attn_output.split(self.hidden_size // tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(tp)])
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

def enable_llama_pos_shift_attention(model):
    """Recursively replace the ``forward`` of every ``LlamaAttention`` submodule.

    Walks the direct submodules of ``model`` in reverse registration order,
    descends into any submodule that has children of its own, and binds
    ``llama_pos_shift_attention_forward`` as the ``forward`` method of each
    ``LlamaAttention`` instance found. The model is modified in place.
    """
    for child_name, child in reversed(model._modules.items()):
        has_children = len(list(child.children())) > 0
        if has_children:
            enable_llama_pos_shift_attention(child)

        if not isinstance(child, LlamaAttention):
            continue

        target = model._modules[child_name]
        target.forward = types.MethodType(llama_pos_shift_attention_forward, target)

# Patch the currently loaded `model` in place so its LlamaAttention layers use
# the position-shifted forward defined above.
enable_llama_pos_shift_attention(model)

In [None]:
# Same prompt and timing as the baseline run, now with the patched attention.
prompt = "Write me a python script for transformer decoder architecture from scratch"
start = time.time()
res = greedy_generate(
    model,
    tokenizer,
    prompt,
    kv_cache=None,
    max_new_tokens=1000,
)
print(f"Output: {res}")
print(f"Time Elapsed: {int(time.time() - start)}")

# 3) KV Caching

[Key Value Caching](https://medium.com/@joaolages/kv-caching-explained-276520203249)

In [None]:
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Dedicated names for the benchmark model: rebinding the notebook-wide `model`
# and `tokenizer` here would make every later cell silently run on GPT-2
# instead of the Llama model loaded (and patched) above.
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Tokenize once, outside the timed region, so only generation is measured.
bench_inputs = gpt2_tokenizer("What is KV caching?", return_tensors="pt").to(device)

for use_cache in (True, False):
    times = []
    for _ in range(10):  # measuring 10 generations
        start = time.time()
        gpt2_model.generate(**bench_inputs, use_cache=use_cache, max_new_tokens=1000)
        times.append(time.time() - start)

    print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

In [None]:
def slice1d(x, start, end):
    """Return ``x`` restricted to ``start:end`` along dimension 1."""
    return x[:, start:end, ...]


def slice2d(x, start, end):
    """Return ``x`` restricted to ``start:end`` along dimension 2."""
    return x[:, :, start:end, ...]


def slice3d(x, start, end):
    """Return ``x`` restricted to ``start:end`` along dimension 3."""
    return x[:, :, :, start:end, ...]


# Lookup from sequence-dimension index to the helper that slices along it.
DIM_TO_SLICE = {
    1: slice1d,
    2: slice2d,
    3: slice3d,
}

class StartRecentKVCache:
    """Eviction policy that keeps the first and the most recent cache entries.

    ``past_key_values`` is treated as an iterable of per-layer ``(key, value)``
    pairs. The sequence length is read from the first layer's key along
    ``k_seq_dim``. Every method returns ``None`` for a ``None`` cache and
    otherwise either the input object unchanged or a new list of ``[key, value]``
    lists.
    """

    def __init__(self, start_size=4, recent_size=512, k_seq_dim=2, v_seq_dim=2):
        """Store the window sizes and pick the slicers for the sequence dims."""
        print(f"StartRecentKVCache: {start_size}, {recent_size}")
        self.start_size = start_size
        self.recent_size = recent_size
        self.cache_size = start_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.k_slice = DIM_TO_SLICE[k_seq_dim]
        self.v_slice = DIM_TO_SLICE[v_seq_dim]

    def _drop_span(self, past_key_values, drop_from, drop_to, total):
        """Keep ``[0, drop_from)`` and ``[drop_to, total)`` of every layer's K and V."""
        trimmed = []
        for k, v in past_key_values:
            kept_k = torch.cat(
                [self.k_slice(k, 0, drop_from), self.k_slice(k, drop_to, total)],
                dim=self.k_seq_dim,
            )
            kept_v = torch.cat(
                [self.v_slice(v, 0, drop_from), self.v_slice(v, drop_to, total)],
                dim=self.v_seq_dim,
            )
            trimmed.append([kept_k, kept_v])
        return trimmed

    def __call__(self, past_key_values):
        """Trim the cache to ``start_size`` leading plus ``recent_size`` trailing entries.

        The input is returned unchanged when it already fits in ``cache_size``.
        """
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values
        return self._drop_span(
            past_key_values, self.start_size, seq_len - self.recent_size, seq_len
        )

    def evict_for_space(self, past_key_values, num_coming):
        """Trim the cache so that ``num_coming`` more entries fit in ``cache_size``.

        The input is returned unchanged when ``seq_len + num_coming`` does not
        exceed ``cache_size``; otherwise the recent window is shortened by
        ``num_coming`` entries.
        """
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values
        recent_from = seq_len - self.recent_size + num_coming
        return self._drop_span(past_key_values, self.start_size, recent_from, seq_len)

    def evict_range(self, past_key_values, start, end):
        """Remove the entries in ``[start, end)`` from every layer's K and V."""
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        assert start <= end and end <= seq_len
        return self._drop_span(past_key_values, start, end, seq_len)

# Both keys and values carry their sequence axis at dimension 2.
k_seq_dim = 2
v_seq_dim = 2

# Keep the first 4 and the latest 2000 cache entries.
kv_cache = StartRecentKVCache(
    start_size=4,
    recent_size=2000,
    k_seq_dim=k_seq_dim,
    v_seq_dim=v_seq_dim,
)

# Reuses `prompt` from the earlier generation cells; the returned text is the cell output.
greedy_generate(model, tokenizer, prompt, kv_cache=kv_cache, max_new_tokens=1000)