In [None]:
# Notebook to check that the speculative-decoding helpers work with a fine-tuned Gemma-2b as the draft model, for faster text generation.
# General procedure for speculative decoding:
#   1. Fine-tune Gemma-2b on the UnifiedQA dataset.
#   2. Use Gemma-2b as the draft model for Gemma-7b and measure the number of generated tokens and the latency.
#   3. Measure throughput, following: https://towardsdatascience.com/deploying-large-language-models-vllm-and-quantizationstep-by-step-guide-on-how-to-accelerate-becfe17396a2

import gc
import json
import logging
import os
from collections import defaultdict

import accelerate
import bitsandbytes
import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig, LoraModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline
from trl import SFTTrainer

from quantization import CONFIG_4BITS, CONFIG_4BITS_NESTED, CONFIG_4BITS_NORM, CONFIG_8BITS, CONFIG_4BITS_NORM_NESTED
from run_utils import *


In [None]:
# Load the Gemma-tokenized UnifiedQA train/dev/test splits.
# The directory is defined once and can be overridden with the UNIFIEDQA_GEMMA_DIR
# environment variable, so the notebook runs on machines other than the author's.
# NOTE(review): these splits are not referenced by the cells below in this notebook.
GEMMA_DATA_DIR = os.environ.get(
    "UNIFIEDQA_GEMMA_DIR",
    "/home/andrusha/Desktop/DL Research/Efficient-LLM-Benchmark/UnifiedQA Data Curation/tokenized/Gemma",
)
gemma_train_dataset = load_tokenized_dataset(os.path.join(GEMMA_DATA_DIR, "train.json"))
gemma_dev_dataset = load_tokenized_dataset(os.path.join(GEMMA_DATA_DIR, "dev.json"))
gemma_test_dataset = load_tokenized_dataset(os.path.join(GEMMA_DATA_DIR, "test.json"))


In [None]:
# Target model (Gemma-7b) and draft model (Gemma-2b), both loaded with the same 4-bit config.
# use_cache=True: disabling the KV cache is a fine-tuning setting. For generation it makes
# every step recompute the full prefix, which distorts the latency/throughput measured below.
# Hugging Face assisted generation (generate(assistant_model=...)) also requires a cache.
# NOTE(review): assumes load_model forwards use_cache to the model config — confirm in run_utils.
gemma_model, gemma_tokenizer = load_model(
    base_model="google/gemma-7b",
    bnb_config=CONFIG_4BITS_NORM_NESTED,
    on_gpu=True,
    use_cache=True,
    pretraining_tp=1,
)

# NOTE(review): the notebook header describes a *fine-tuned* Gemma-2b draft, but this loads the
# base checkpoint — swap in the fine-tuned weights/adapter once available.
# NOTE(review): on_gpu=False places the draft differently from the target. Confirm that
# speculative_decoding/throughput handle cross-device tensors, and that 4-bit bitsandbytes
# loading works off-GPU in the installed version.
gemma_model_2b, gemma_tokenizer_2b = load_model(
    base_model="google/gemma-2b",
    bnb_config=CONFIG_4BITS_NORM_NESTED,
    on_gpu=False,
    use_cache=True,
    pretraining_tp=1,
)

# Speculative decoding only works when draft and target share a vocabulary. Keep the target's
# tokenizer for encoding/decoding and verify the draft's matches it, instead of silently
# overwriting one with the other.
assert gemma_tokenizer.get_vocab() == gemma_tokenizer_2b.get_vocab(), "draft and target tokenizers differ"




In [None]:
# Single-prompt smoke test: Gemma-7b as the target, Gemma-2b as the draft.
# Reference: https://towardsdatascience.com/deploying-large-language-models-vllm-and-quantizationstep-by-step-guide-on-how-to-accelerate-becfe17396a2
# No attention mask is requested; this relies on a single, unpadded prompt.
prompt_text = "Generate a python code that accepts a list of numbers and returns the sum."
inputs = gemma_tokenizer(
    prompt_text,
    return_tensors='pt',
    return_attention_mask=False,
)
speculative_decoding(gemma_model, gemma_model_2b, inputs, gemma_tokenizer)

In [None]:
# Throughput measurement on the prompt encoded above, using the target and draft models
# (see run_utils.throughput for what exactly is timed).
# NOTE(review): 200 and .1 are unnamed positional arguments — presumably a generation-length
# budget and a sampling/threshold value; confirm their meaning and promote them to named constants.
throughput(gemma_model, gemma_model_2b, gemma_tokenizer, inputs, 200, .1)

In [None]:
# Tear down both models and return GPU memory.
# del_model_of_gpu only receives a reference, so it cannot remove the notebook's own global
# names. While those names exist, the weights stay reachable and their memory cannot be
# freed. Drop the globals as well, collect the garbage, and then let the CUDA caching
# allocator release its unused blocks (a no-op when CUDA was never initialized).
del_model_of_gpu(gemma_model)
del_model_of_gpu(gemma_model_2b)
del gemma_model, gemma_model_2b
gc.collect()
torch.cuda.empty_cache()