# Notebook to test that the speculative decoding functions work, using a fine-tuned Gemma-2b as the draft model for faster text inference
### Procedure in general for Speculative Decoding:
1. Fine-tune Gemma-2b on the UnifiedQA dataset.
2. Use Gemma-2b as the draft model for Gemma-7b and look at the number of tokens generated together with the latency.
3. To measure throughput, see: https://towardsdatascience.com/deploying-large-language-models-vllm-and-quantizationstep-by-step-guide-on-how-to-accelerate-becfe17396a2


In [None]:
# Mount Google Drive at /content/drive so the project files stored there are
# reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%load_ext autoreload
%autoreload 2
%cd drive/MyDrive/Efficient LLM Benchmarks

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
[Errno 2] No such file or directory: 'drive/MyDrive/Efficient LLM Benchmarks'
/content/drive/.shortcut-targets-by-id/1ttMDC1vO7-d3JL9f_kDjpBd1LpxUxMSl/Efficient LLM Benchmarks


In [None]:
!pip install datasets
!pip install accelerate
!pip install bitsandbytes
!pip install peft
!pip install trl
!pip install transformers
!pip install evaluate
!pip install rouge_score

Collecting datasets
  Downloading datasets-2.19.0-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub>=0.21.2 (from datasets)
  Downloading huggingface_hub-0.22.2-py3-none-an

In [None]:
# Authenticate with the Hugging Face Hub through the interactive login widget.
# NOTE(review): presumably needed because the checkpoints used below are gated
# on the Hub - confirm.
from huggingface_hub import notebook_login
notebook_login() # use your access token here!

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import os
import json
import time
import torch

import datasets
import accelerate

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM

from Experiments.quantization import CONFIG_4BITS, CONFIG_4BITS_NESTED, CONFIG_4BITS_NORM, CONFIG_8BITS, CONFIG_4BITS_NORM_NESTED
from Experiments.run_utils import load_model, speculative_decoding, load_tokenized_dataset, del_model_off_gpu
from Experiments.eval_utils import strip_output_text, strip_answers

In [None]:
# --- Model selection ---------------------------------------------------------
# support_model_name picks the locally fine-tuned checkpoint (see
# local_trained_model_path at the bottom of this cell); base_model_name picks
# the model whose tokenized test data is looked up in base_model_test_data_map.
# Other base models with entries in the maps below:
#   "google/gemma-2b", 'meta-llama/Llama-2-7b-hf', 'mistralai/Mistral-7B-v0.1'
support_model_name = "google/gemma-2b"
base_model_name = "google/gemma-7b"

# --- Locations on the mounted Drive ------------------------------------------
base_path = '/content/drive/MyDrive/Efficient LLM Benchmarks'
local_models_path = '/'.join((base_path, 'Experiments', 'trained_models'))

# Hub model id -> folder name of the matching tokenized dataset.
base_model_test_data_map = dict([
    ("google/gemma-2b", 'Gemma_NEW'),
    ("google/gemma-7b", 'Gemma_NEW'),
    ('meta-llama/Llama-2-7b-hf', 'Llama_NEW'),
    ('mistralai/Mistral-7B-v0.1', 'Mistral_NEW'),
])

# Hub model id -> checkpoint folder, relative to local_models_path.
base_model_local_model_map = dict([
    ("google/gemma-2b",
     'gemma_2b_qlora_4bits_norm_nested_outputs/gemma_2b_qlora_4bits_norm_nested_final/'),
    ("google/gemma-7b",
     'gemma_7b_qlora_4bits_norm_nested_outputs/gemma_7b_qlora_4bits_norm_nested_final'),
    ('meta-llama/Llama-2-7b-hf',
     'llama2_7b_qlora_4bits_norm_nested_outputs/llama2_7b_qlora_4bits_norm_nested_final'),
    ('mistralai/Mistral-7B-v0.1',
     'mistral_7b_qlora_4bits_norm_nested_outputs/mistral_7b_qlora_4bits_norm_nested_final'),
])

# Full path of the fine-tuned checkpoint for the selected support model;
# left as the last expression so the notebook displays it.
local_trained_model_path = '/'.join(
    (local_models_path, base_model_local_model_map[support_model_name]))
local_trained_model_path

'/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/trained_models/gemma_2b_qlora_4bits_norm_nested_outputs/gemma_2b_qlora_4bits_norm_nested_final/'

In [None]:
# Load test.json from the tokenized UnifiedQA folder that
# base_model_test_data_map associates with the selected base model.
# NOTE(review): the return type of load_tokenized_dataset is not visible here;
# later cells call .select() / .shuffle() on it, so it is presumably a
# datasets.Dataset - confirm.
test_data = load_tokenized_dataset(os.path.join(
    f"{base_path}/UnifiedQA Data Curation/tokenized_NEW/{base_model_test_data_map[base_model_name]}",
    "test.json"))


In [None]:
# Load the locally fine-tuned checkpoint selected by support_model_name onto
# device 0; it is passed to spec_decod_wrapper as the assistant model.
# NOTE(review): the checkpoint folder name contains "4bits_norm_nested" but the
# model is loaded with CONFIG_4BITS - confirm this quantization config is the
# intended one.
assistant_model = AutoPeftModelForCausalLM.from_pretrained(
    local_trained_model_path, device_map={"": 0}, quantization_config=CONFIG_4BITS)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Load the base (target) model and its tokenizer through the project helper.
# NOTE(review): use_cache=False presumably disables the KV cache; if that
# setting carries over to generation it would slow decoding considerably, which
# matters for the latency numbers collected below - confirm against load_model.
large_model, tokenizer = load_model(
    base_model=base_model_name, bnb_config=CONFIG_4BITS, on_gpu=True, use_cache=False, pretraining_tp=1)

config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.11G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [None]:
def spec_decod_wrapper(trained_model:AutoModelForCausalLM, tokenizer:AutoTokenizer, assistant_model:AutoModelForCausalLM, ds:datasets.Dataset, model_name:str, sample:int=1000, seed:int=42, save_path:str=''):
    """Run speculative decoding over (a sample of) ``ds`` and collect the results.

    Args:
        trained_model: Model forwarded to ``speculative_decoding`` as its first
            argument.
        tokenizer: Tokenizer forwarded to ``speculative_decoding``.
        assistant_model: Assistant model forwarded to ``speculative_decoding``.
        ds: Dataset to evaluate. Must provide an ``id`` column whose values
            contain the source-dataset name before the first ``-`` and an
            ``answers`` column.
        model_name: Forwarded to ``strip_output_text`` and ``strip_answers``.
        sample: Number of rows to evaluate. When positive, ``ds`` is shuffled
            with ``seed`` and the first ``sample`` rows are kept (clamped to
            ``len(ds)``); when zero or negative, all rows are evaluated in
            their original order.
        seed: Seed used for the shuffle when ``sample`` is positive.
        save_path: When non-empty, path of a JSON file to write the results to.
            Missing parent directories are created.

    Returns:
        A ``datasets.Dataset`` with the columns ``predictions`` (lower-cased),
        ``ground_truth``, ``original_dataset`` and ``latencies``.
    """
    def add_dataset_name_col(subset):
        # Read only the 'id' column instead of materialising every row.
        original_dataset = [example_id.split('-')[0] for example_id in subset['id']]
        return datasets.concatenate_datasets(
            [subset, datasets.Dataset.from_dict({'original_dataset': original_dataset})],
            axis=1)

    if sample > 0:
        # Clamp so a dataset smaller than `sample` (default 1000) does not make
        # select() raise on out-of-range indices.
        sample_size = min(sample, len(ds))
        sample_data = ds.shuffle(seed=seed).select(range(sample_size))
        eval_sample = add_dataset_name_col(sample_data)
    else:
        eval_sample = add_dataset_name_col(ds)
    print("eval_sample generated")

    predictions, latencies = speculative_decoding(trained_model, tokenizer, assistant_model, eval_sample)
    print("predictions generated")

    predictions = [strip_output_text(s, model_name) for s in predictions]
    answers_stripped = [strip_answers(s, model_name) for s in eval_sample['answers']]

    # NOTE(review): predictions are lower-cased here but the answers are not;
    # presumably strip_answers already normalises case - confirm.
    pred_ds = datasets.Dataset.from_dict({
        'predictions': [p.lower() for p in predictions],
        'ground_truth': answers_stripped,
        'original_dataset': eval_sample['original_dataset'],
        'latencies': latencies})

    if save_path:
        # os.path.dirname handles absolute paths, relative paths and bare file
        # names alike; only create a directory when there is one to create.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            print(out_dir)
            os.makedirs(out_dir, exist_ok=True)
        print(save_path)
        # Stored as a list of four parallel lists:
        # [predictions, ground_truth, original_dataset, latencies].
        with open(save_path, "w") as f:
            json.dump([pred_ds['predictions'], pred_ds['ground_truth'], pred_ds['original_dataset'], pred_ds['latencies']], f)

    return pred_ds



In [None]:
# Smoke test: run the wrapper on a 10-example sample and write the result to a
# JSON file under Experiments/predictions/spec_decod.
pred_ds = spec_decod_wrapper(
    large_model, tokenizer, assistant_model, test_data,
    base_model_name, sample=10,
    save_path=f'{base_path}/Experiments/predictions/spec_decod/predictions_spec_decod_7b_gemma_using_2b_qlora_gemma.json')

Flattening the indices:   0%|          | 0/10 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/10 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/predictions_spec_decod_7b_gemma_using_2b_qlora_gemma.json


In [None]:
# Batched predictions
# Evaluate the first NUM_EVAL_EXAMPLES rows of test_data in chunks of
# BATCH_SIZE, writing one JSON file per chunk so partial results are kept.
BATCH_SIZE = 50
NUM_EVAL_EXAMPLES = 500
# Never index past the end of the test set.
num_to_eval = min(NUM_EVAL_EXAMPLES, len(test_data))
for i in range(0, num_to_eval, BATCH_SIZE):
    start_time = time.time()
    print(f"Start time batch {i}: {start_time}")
    # The last chunk may be shorter than BATCH_SIZE.
    batch = test_data.select(range(i, min(i + BATCH_SIZE, num_to_eval)))
    # sample=len(batch) keeps every row of the chunk; the wrapper shuffles the
    # rows within the chunk because sample is positive.
    pred_ds = spec_decod_wrapper(
        large_model, tokenizer, assistant_model, batch,
        base_model_name, sample=len(batch),
        save_path=f'{base_path}/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_{i}.json')
    # Drop references to this chunk before starting the next one.
    del batch
    del pred_ds
    print(f"Total time batch {i}: {time.time() - start_time}")

Start time batch 0: 1714490127.7038753


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_0.json
Total time batch 0: 1174.106873035431
Start time batch 50: 1714491301.8112874


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_50.json
Total time batch 50: 1088.7893595695496
Start time batch 100: 1714492390.6007335


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_100.json
Total time batch 100: 1066.9574563503265
Start time batch 150: 1714493457.5587316


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_150.json
Total time batch 150: 1049.138526916504
Start time batch 200: 1714494506.697328


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_200.json
Total time batch 200: 985.6807007789612
Start time batch 250: 1714495492.3781047


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_250.json
Total time batch 250: 1083.4507684707642
Start time batch 300: 1714496575.8289382


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_300.json
Total time batch 300: 1162.416612625122
Start time batch 350: 1714497738.2456183


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_350.json
Total time batch 350: 997.2684001922607
Start time batch 400: 1714498735.5145624


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_400.json
Total time batch 400: 1251.6020686626434
Start time batch 450: 1714499987.117192


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

eval_sample generated
cuda is available
predictions generated
content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b
/content/drive/MyDrive/Efficient LLM Benchmarks/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_450.json
Total time batch 450: 1131.7266387939453


In [None]:
for i in range(0, 500, BATCH_SIZE):
  start_time = time.time()
  print(f"Start time batch {i}: {start_time}")
  s = test_data.select(range(i, i+BATCH_SIZE))
  pred_ds = spec_decod_wrapper(
    large_model, tokenizer, assistant_model, s,
    base_model_name, sample=BATCH_SIZE,
    save_path=f'{base_path}/Experiments/predictions/spec_decod/gemma_7b/assistant_qlora_gemma_2b/predictions_batch_{i}.json')