In [1]:
import os
# Expose only GPU 0 to this process. The assignment is placed before
# `import torch` below so the variable is already set when torch is loaded.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import transformers
import sys
# Add the parent directory to the module search path; presumably this is what
# lets the next cell import the local `sparsedoping` package -- TODO confirm layout.
sys.path.append('../')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from sparsedoping.model import LlamaSparseModelForCausalLM

from sparsedoping.model import LlamaSparseModelForCausalLM, LlamaSparseConfig

from transformers import AutoConfig, AutoModelForCausalLM

# --- Model configuration ------------------------------------------------------
# Hugging Face model id, shared by the tokenizer and the sparse model so the two
# loads cannot drift apart.
MODEL_ID = "meta-llama/Llama-2-7B-hf"
# Directory holding this model's artefacts (the "histograms" and "lookup"
# sub-directories used below). The default is the original hard-coded location;
# set SPARSE_MODEL_DIR to run the notebook on another machine.
MODEL_ARTIFACT_DIR = os.environ.get("SPARSE_MODEL_DIR", "/home/jamesliu/models/Llama-2-7B")

# Register the custom config/model pair with the transformers Auto* factories
# under the "llama_sparse" model type.
AutoConfig.register("llama_sparse", LlamaSparseConfig)
AutoModelForCausalLM.register(LlamaSparseConfig, LlamaSparseModelForCausalLM)

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)
model = LlamaSparseModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
    histogram_path=os.path.join(MODEL_ARTIFACT_DIR, "histograms"),
    apply_prefill=False,
).to("cuda")

# Path handed to get_layer_greedy_sparsities by the (currently commented-out)
# greedy-sparsity alternative in the benchmark cell.
greedy_sparsity_path = os.path.join(MODEL_ARTIFACT_DIR, "lookup")



You are using a model of type llama to instantiate a model of type llama_sparse. This is not supported for all configurations of models and can yield errors.
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  9.30it/s]


In [3]:
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    """Run one greedy decoding step and return the selected token id(s).

    ``model`` is called with ``cur_token`` plus the position and cache arguments
    (``return_dict=False``, ``use_cache=True``) and must return a sequence whose
    first element is the logits tensor. The argmax over the last axis of the
    final position's logits is returned with a new axis inserted at dim 1.
    """
    outputs = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True,
    )
    # Only the logits of the last sequence position decide the next token.
    last_step_logits = outputs[0][:, -1]
    return last_step_logits.argmax(dim=-1).unsqueeze(1)

def graph_wrapper(fn, *init_args, **init_kwargs):
    """Capture ``fn(*init_args, **init_kwargs)`` in a CUDA graph and return a replay function.

    ``fn`` is run once on a side stream, then run a second time under
    ``torch.cuda.graph`` to record it. The argument objects passed here are kept
    as the graph's static inputs, and the value ``fn`` returns during capture is
    kept as the static output.

    Args:
        fn: Callable to capture.
        *init_args: Positional arguments used for both the warm-up run and the capture.
        **init_kwargs: Keyword arguments used for both the warm-up run and the capture.

    Returns:
        A function ``replay(*args, **kwargs)`` that copies every tensor argument
        into the captured argument at the same position / keyword, replays the
        graph, and returns the same ``static_output`` object on every call.
        Non-tensor arguments passed to ``replay`` are ignored.
    """
    # Warm-up run on a side stream before capture; the side stream first waits
    # for work already queued on the current stream.
    s = torch.cuda.Stream(device="cuda")
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(*init_args, **init_kwargs)

    # Make the current stream wait for the warm-up before capturing.
    torch.cuda.current_stream().wait_stream(s)
    graph = torch.cuda.CUDAGraph()

    # Second run of fn: this one is recorded into `graph`.
    with torch.cuda.graph(graph, stream=s):
        static_output = fn(*init_args, **init_kwargs)

    # The very objects used during capture; replay() writes new values into them.
    static_args = init_args
    static_kwargs = init_kwargs
    

    def replay(*args, **kwargs):
        """Copy tensor inputs into the captured tensors, replay the graph, return the static output."""
        # NOTE(review): assumes a tensor passed here lines up with a tensor captured
        # at the same position / keyword; a mismatch (e.g. None at capture time)
        # would fail on .copy_ -- confirm callers keep the argument layout stable.
        for i in range(len(args)):
            if isinstance(args[i], torch.Tensor):
                static_args[i].copy_(args[i])
        for kw in kwargs:
            if isinstance(kwargs[kw], torch.Tensor):
                static_kwargs[kw].copy_(kwargs[kw])

        graph.replay()
        # The same object is returned on every call; a later replay overwrites it.
        return static_output
    
    return replay

In [37]:
from utils import get_layer_greedy_sparsities

from data import get_dataset
from transformers import StaticCache
import torch
from tqdm import tqdm
import os

import torch.cuda.nvtx as nvtx


# assumes model mode is default or turbo
def eval_speed(model, tokenizer, device, sparsities, dataset, debug=False):
    """Measure single-token decode throughput using a CUDA-graph-captured decode step.

    The dataset text is concatenated and tokenized, the first ``prefill_len``
    tokens are run through the model to fill a ``StaticCache``, one decode step is
    captured with ``graph_wrapper``, and the captured step is then replayed in a
    timed loop. Each stage is wrapped in a balanced NVTX range so profiler
    timelines nest correctly.

    Args:
        model: Callable model exposing ``eval()``, ``reset_sparsities()``,
            ``set_sparsities()``, ``config``, ``device`` and ``dtype``. The caller
            is expected to have selected the "default" or "turbo" sparsity mode.
        tokenizer: Callable returning an encoding that has ``.to(device)`` and
            ``input_ids``.
        device: Device the tokenized inputs are moved to.
        sparsities: Passed unchanged to ``model.set_sparsities``.
        dataset: Iterable of mappings with a ``"text"`` entry.
        debug: Currently unused; kept for interface compatibility.

    Returns:
        Tuple ``(tokens_per_second, seconds_per_token)``.

    Raises:
        ValueError: If the dataset text tokenizes to fewer than ``max_len`` tokens,
            which the fixed-size buffers below require.
    """
    # perf_counter is monotonic and high-resolution, unlike wall-clock time().
    from time import perf_counter

    prefill_len = 128
    max_len = prefill_len + 512
    batch_size = 1

    with nvtx.range("eval_speed"):
        model.eval()

        text = "".join(sample["text"] + "\n\n" for sample in dataset)

        # truncation=True makes the max_length cut-off explicit instead of relying
        # on the tokenizer's implicit fallback.
        all_encodings = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=max_len
        ).to(device)
        prefill_encodings = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=prefill_len
        ).to(device)

        available_tokens = all_encodings["input_ids"].shape[1]
        if available_tokens < max_len:
            raise ValueError(
                f"eval_speed needs at least {max_len} tokens of text, "
                f"but the dataset only produced {available_tokens}"
            )

        with nvtx.range("model_setup"):
            model.reset_sparsities()
            model.set_sparsities(sparsities)

        with torch.no_grad():
            with nvtx.range("prefill"):
                past_key_values = StaticCache(
                    config=model.config,
                    max_batch_size=batch_size,
                    max_cache_len=4096,
                    device=model.device,
                    dtype=model.dtype,
                )
                cache_position = torch.arange(prefill_len, device=model.device)
                # (batch, total tokens) buffer of token ids.
                generated_ids = torch.zeros(
                    batch_size, max_len, dtype=torch.int, device=model.device
                )
                generated_ids[:, cache_position] = (
                    prefill_encodings["input_ids"].to(model.device).to(torch.int)
                )

                logits = model(
                    **prefill_encodings,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    return_dict=False,
                    use_cache=True,
                )[0]

                next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
                generated_ids[:, prefill_len] = next_token[:, 0]
                cache_position = torch.tensor([prefill_len + 1], device=model.device)
                torch.cuda.synchronize()

            with nvtx.range("graph_capture"):
                # next_token and cache_position become the graph's static inputs.
                wrapped_decode_one_tokens = graph_wrapper(
                    decode_one_tokens,
                    model,
                    next_token,
                    None,
                    cache_position,
                    past_key_values,
                )
                torch.cuda.synchronize()

            generated_tokens = 0
            start_time = perf_counter()

            with nvtx.range("token_generation_loop"):
                for i in tqdm(range(0, max_len - prefill_len - 1)):
                    with nvtx.range(f"token_{i}"):
                        # NOTE(review): the graph was captured outside this context, so
                        # the kernel selection here may not influence the replayed
                        # kernels -- confirm whether it is still needed.
                        with torch.backends.cuda.sdp_kernel(
                            enable_flash=False, enable_mem_efficient=False, enable_math=True
                        ):
                            # clone(): the wrapper returns the same static output tensor on
                            # every replay, so copy it before feeding it back in.
                            next_token = wrapped_decode_one_tokens(
                                model, next_token.clone(), None, cache_position, past_key_values
                            )
                            # The buffer is filled from the reference text, not from the
                            # model's prediction; it does not affect the returned timing.
                            generated_ids[:, cache_position] = all_encodings.input_ids[:, cache_position].int()
                        # In-place so the tensor captured by the graph advances too.
                        cache_position += 1
                        generated_tokens += 1

            # Wait for all queued GPU work before stopping the clock.
            torch.cuda.synchronize()
            end_time = perf_counter()

        tokens_per_second = generated_tokens / (end_time - start_time)
        return tokens_per_second, 1 / tokens_per_second

In [35]:
from data import get_dataset
from utils import get_layer_greedy_sparsities
# Benchmark text: WikiText-2 (raw) train split, requested with size=100.
dataset = get_dataset("wikitext", subset="wikitext-2-raw-v1", split="train", size=100)

model.convert_column_mode()

# Uniform sparsity target: every projection gets one entry per decoder layer.
sparsity_level = 0.5
projs = ['up', 'gate', 'down', 'q', 'k', 'v', 'o']
sparsities = dict(
    (proj, [sparsity_level] * len(model.model.layers)) for proj in projs
)

# Alternative: greedy per-layer sparsities instead of a uniform level.
# sparsities = get_layer_greedy_sparsities([sparsity_level]*len(model.model.layers), greedy_sparsity_path)

# Mode used for this run; the alternative is model.set_sparsity_mode("default").
model.set_sparsity_mode("turbo")

# Last expression of the cell, so its return value is displayed as the cell output.
eval_speed(model, tokenizer, "cuda", sparsities, dataset)

100%|██████████| 895/895 [00:08<00:00, 110.81it/s] 


(96.70615449744425, 0.010340603503435018)

In [4]:
from data import get_dataset
# Evaluation text: WikiText-2 (raw) train split, requested with size=100.
dataset = get_dataset("wikitext", subset="wikitext-2-raw-v1", split="train", size=100)

model.set_sparsity_mode("turbo")

# Sparsity level 0 for every projection, one entry per decoder layer.
sparsity_level = 0
projs = ['up', 'gate', 'down', 'q', 'k', 'v', 'o']
sparsities = dict(
    (proj, [sparsity_level] * len(model.model.layers)) for proj in projs
)
model.set_sparsities(sparsities)
# Alternative mode: model.set_sparsity_mode("dev")

from eval_ppl import eval_ppl

# Last expression of the cell, so the value returned by eval_ppl is displayed.
eval_ppl(model, tokenizer, "cuda", dataset)

100it [00:01, 85.75it/s]


6.059079046405714