Evaluations with zero-shot, 5-shot, full fine-tune, merged LoRA, merged DoRA.

In [1]:
import torch
from pathlib import Path
import os, json
from safetensors.torch import save_file
import copy
from tqdm import tqdm
import safetensors
import safetensors.torch
from glob import glob
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from hqq.core.quantize import HQQLinear, HQQBackend, BaseQuantizeConfig, Quantizer

from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer
from hqq.backends.torchao import patch_hqq_to_aoint4
from fastcore.script import *

from accelerate import init_empty_weights

### Compression Rates

In [35]:
# Instantiate the model from its config under accelerate's init_empty_weights(),
# so only the parameter shapes are needed for the compression estimates below.
model_name = "meta-llama/Meta-Llama-3.1-405B-Instruct"
cfg = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
	model = AutoModelForCausalLM.from_config(cfg)

In [36]:
def compute_quantized_layer_params(lora_rank=256, groupsize_4bit=128, groupsize_2bit=32, with_dora=True, mlp_in_4bit=False, dora_2bit=False, layer=None):
    """
    Total param counts for a quantized decoder layer.

    Counts are expressed in bf16-parameter equivalents: a 4-bit weight costs 1/4
    of a parameter and a 2-bit weight 1/8, and every quantization group adds one
    zero-point and one scale, each counted as a full parameter.

    Args:
        lora_rank: rank of the LoRA matrices counted for each adapted weight.
        groupsize_4bit: quantization group size used for 4-bit weights.
        groupsize_2bit: quantization group size used for 2-bit weights.
        with_dora: count DoRA parameters for attention and MLP weights.
        mlp_in_4bit: count MLP weights as 4-bit instead of 2-bit.
        dora_2bit: count DoRA parameters for MLP weights even when `with_dora`
            is False (attention weights are unaffected by this flag).
        layer: module whose `named_parameters()` are counted. Defaults to
            `model.model.layers[0]` of the notebook-level `model`.

    Returns:
        Estimated parameter count (float) for the layer.
    """

    # Compression rates compared to bf16
    COMPRESSION_RATE_4BIT = 4
    COMPRESSION_RATE_2BIT = 8

    if layer is None:
        layer = model.model.layers[0]

    def _quantized_count(p, compression_rate, groupsize):
        # packed qweight + one zero and one scale per quantization group
        return p.numel() / compression_rate + 2 * (p.numel() / groupsize)

    def _dora_count(p):
        # Weight matrices are laid out as (out_features, in_features).
        out_features, in_features = p.size(0), p.size(1)
        # lora_A (rank x in) + lora_B (out x rank) + magnitude (one value per output row)
        return lora_rank * (in_features + out_features) + out_features

    total_params = 0
    for n, p in layer.named_parameters():
        is_attn = "self_attn" in n
        is_mlp = "mlp" in n
        if p.dim() != 2 or not (is_attn or is_mlp):
            # Norm weights, biases and any other non-matrix parameter are kept as-is.
            num_params = p.numel()
        elif is_attn:
            # Attention projections are always counted as 4-bit.
            num_params = _quantized_count(p, COMPRESSION_RATE_4BIT, groupsize_4bit)
            if with_dora:
                num_params += _dora_count(p)
        else:
            if mlp_in_4bit:
                num_params = _quantized_count(p, COMPRESSION_RATE_4BIT, groupsize_4bit)
            else:
                num_params = _quantized_count(p, COMPRESSION_RATE_2BIT, groupsize_2bit)
            if with_dora or dora_2bit:
                num_params += _dora_count(p)
        total_params += num_params
    return total_params

In [37]:
def compute_compression_ratio(block_influence_ratio = 1.0, **kwargs):
    """
    Ratio of estimated quantized parameters to original parameters over all decoder layers.

    `int(num_layers * block_influence_ratio)` layers are counted with 4-bit MLPs
    (`mlp_in_4bit=True`); the remaining layers are counted with whatever `kwargs`
    selects. All `kwargs` are forwarded to `compute_quantized_layer_params`.
    Only parameters under `model.model.layers` enter the denominator.
    """
    decoder_layers = model.model.layers
    per_layer_mixed = compute_quantized_layer_params(**kwargs)  # mixed bit param counts
    per_layer_4bit = compute_quantized_layer_params(mlp_in_4bit=True, **kwargs)  # 4bit param counts

    n_layers = len(decoder_layers)
    n_layers_4bit = int(n_layers * block_influence_ratio)
    n_layers_mixed = n_layers - n_layers_4bit

    quantized_total = n_layers_mixed * per_layer_mixed + n_layers_4bit * per_layer_4bit

    original_total = 0
    for _, param in decoder_layers.named_parameters():
        original_total += param.numel()

    return quantized_total / original_total

In [38]:
# 4bit (128) HQQ (without DoRA) -> %0.27
# Every layer is counted with 4-bit MLPs (block_influence_ratio=1.0) and no adapter parameters.
compute_compression_ratio(1.0, with_dora=False)

0.2656325490075143

In [39]:
# 4bit (128) / 2bit (32), lora rank=256, block influence ratio=0.2 -> compression % 0.26
# 4bit (128), lora rank=256 -> compression % 0.31
# NOTE(review): the two figures above equal the llama-3.1-70b entries in the summary
# cell further down; the output recorded for this cell is (0.239, 0.291) -- confirm
# which model the comments refer to.

compute_compression_ratio(0.2),compute_compression_ratio(1.0)

(0.23909828474311776, 0.2905153627121432)

In [40]:
# 4bit (128) / 2bit (32), lora rank=128, block influence ratio=0.2 -> compression % 0.24
# 4bit (128), lora rank=128 -> compression % 0.29
# NOTE(review): these figures equal the llama-3.1-70b entries in the summary cell;
# the output recorded for this cell is (0.227, 0.278) -- confirm which model they refer to.

compute_compression_ratio(0.2, lora_rank=128),compute_compression_ratio(1.0, lora_rank=128)

(0.22668064923361436, 0.27809772720263976)

In [41]:
# 4bit (128) / 2bit (32), lora rank=64, block influence ratio=0.2 -> compression % 0.23
# 4bit (128), lora rank=64 -> compression % 0.28
# NOTE(review): these figures equal the llama-3.1-70b entries in the summary cell;
# the output recorded for this cell is (0.220, 0.272) -- confirm which model they refer to.

compute_compression_ratio(0.2, lora_rank=64),compute_compression_ratio(1.0, lora_rank=64)

(0.22047183147886265, 0.27188890944788807)

In [42]:
# 4bit (128), 2bit (64), lora rank=64, block influence ratio=0.2 -> compression % 0.21
# NOTE(review): this figure equals the llama-3.1-70b entry in the summary cell;
# the output recorded for this cell is 0.200 -- confirm which model it refers to.
compute_compression_ratio(0.2, lora_rank=64, groupsize_2bit=64)

0.1999050002912525

In [90]:
# llama-3.1-8b
# 4bit (128) HQQ (without DoRA) -> %0.27

# 4bit (128), lora rank=256 -> compression % 0.36
# 4bit (128) / 2bit (32), lora rank=256, block influence ratio=0.2 -> compression % 0.31

# 4bit (128), lora rank=128 -> compression % 0.31
# 4bit (128) / 2bit (32), lora rank=128, block influence ratio=0.2 -> compression % 0.26

# 4bit (128), lora rank=64 -> compression % 0.29
# 4bit (128) / 2bit (32), lora rank=64, block influence ratio=0.2 -> compression % 0.24

# 4bit (128), 2bit (64), lora rank=64, block influence ratio=0.2 -> compression % 0.22


# llama-3.1-70b
# 4bit (128) HQQ (without DoRA) -> %0.27
# 4bit (128), lora rank=256 -> compression % 0.31
# 4bit (128) / 2bit (32), lora rank=256, block influence ratio=0.2 -> compression % 0.26
# 4bit (128), lora rank=128 -> compression % 0.29
# 4bit (128) / 2bit (32), lora rank=128, block influence ratio=0.2 -> compression % 0.24
# 4bit (128), lora rank=64 -> compression % 0.28
# 4bit (128) / 2bit (32), lora rank=64, block influence ratio=0.2 -> compression % 0.23
# 4bit (128), 2bit (64), lora rank=64, block influence ratio=0.2 -> compression % 0.21


# Another idea: Only apply lora to 2 bit quantized layers and freeze the 4bit layers.




In [91]:
# Root directory of the exported checkpoints inspected in the next cells.
MODELS_DIR = Path("/workspace/models")

In [92]:
# Checkpoint sub-directory (relative to MODELS_DIR) whose first shard and config.json are read below.
PATH = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-20pct/step_125/vllm_bitblas"

In [93]:
# Load the first safetensors shard of the checkpoint as a name -> tensor mapping.
weights = safetensors.torch.load_file(MODELS_DIR / PATH / "model-00001-of-00004.safetensors")

In [94]:
# Print name and shape of every tensor that belongs to decoder layer 0 or layer 3.
for k, p in weights.items():
    if not any(tag in k for tag in ("layers.0.", "layers.3.")):
        continue
    print(k, p.shape)

model.layers.0.input_layernorm.weight torch.Size([4096])
model.layers.0.mlp.down_proj.lora_A torch.Size([256, 14336])
model.layers.0.mlp.down_proj.lora_B torch.Size([4096, 256])
model.layers.0.mlp.down_proj.qweight torch.Size([4096, 7168])
model.layers.0.mlp.down_proj.rescale torch.Size([4096])
model.layers.0.mlp.down_proj.scales torch.Size([4096, 112])
model.layers.0.mlp.down_proj.zeros torch.Size([4096, 112])
model.layers.0.mlp.gate_proj.lora_A torch.Size([256, 4096])
model.layers.0.mlp.gate_proj.lora_B torch.Size([14336, 256])
model.layers.0.mlp.gate_proj.qweight torch.Size([14336, 2048])
model.layers.0.mlp.gate_proj.rescale torch.Size([14336])
model.layers.0.mlp.gate_proj.scales torch.Size([14336, 32])
model.layers.0.mlp.gate_proj.zeros torch.Size([14336, 32])
model.layers.0.mlp.up_proj.lora_A torch.Size([256, 4096])
model.layers.0.mlp.up_proj.lora_B torch.Size([14336, 256])
model.layers.0.mlp.up_proj.qweight torch.Size([14336, 2048])
model.layers.0.mlp.up_proj.rescale torch.Size([

In [95]:
# Parse the checkpoint's config.json. Path.read_text() opens and closes the file,
# so no file handle is left dangling (the previous bare open() was never closed).
config = json.loads((MODELS_DIR / PATH / "config.json").read_text())

In [97]:
# optim_states = torch.load(MODELS_DIR / PATH / "optimizer.bin")

In [98]:
# len(optim_states['param_groups'])

In [99]:
# for k,v in optim_states['state'].items():
#     if "layers.0" in k:
# 	    print(k, optim_states['state'][k]['exp_avg'].shape, optim_states['state'][k]['exp_avg_sq'].shape)

In [100]:
# k, optim_states['state'][k].keys(), optim_states['state'][k]['exp_avg'].shape

In [101]:
# with init_empty_weights():
# 	model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

In [102]:
# def print_modules(model, prefix="model"):
#     for name, module in model.named_children():
#         if len(list(module.children())) > 0:
#             print_modules(module, prefix + "." + name)
#         print(prefix + "." + name)

In [103]:
# "18,19,20,26,27,28".split(",")

In [104]:
# print_modules(model)

### Evaluate

In [2]:
import re
from datasets import load_dataset
from fastcore.parallel import parallel
from vllm import LLM, SamplingParams

In [3]:
def extract_last_number_or_ratio(s):
    """
    Return the last number-like token in `s`, or None when there is none.

    A token is a run of digits with an optional leading currency symbol ($, € or £),
    an optional decimal part, and an optional `:number` ratio suffix. The matched
    text is returned unchanged (currency symbol included).
    """
    last_match = None
    # Walk every match left to right and remember only the most recent one.
    for match in re.finditer(r'[\$€£]?\d+(?:\.\d+)?(?:\:\d+(?:\.\d+)?)?', s):
        last_match = match.group(0)
    return last_match

In [4]:
def exact_match_score(preds, labels):
    """
    Fraction of predictions equal to their paired label.

    Pairs are formed with zip(), so extra entries in the longer sequence are
    ignored when counting matches; the denominator is always len(preds).
    """
    n_correct = 0
    for pred, gold in zip(preds, labels):
        if pred == gold:
            n_correct += 1
    return n_correct / len(preds)

In [5]:
# Shuffle the train split with a fixed seed, then keep the last 5000 rows of the shuffled data.
dataset = load_dataset("microsoft/orca-math-word-problems-200k")['train'].shuffle(seed=42)
dataset = dataset.select(range(len(dataset)-5000,len(dataset)))
# Reference short answers: the last number/ratio found in each reference answer text.
short_answers_gt = parallel(extract_last_number_or_ratio, dataset['answer'], progress=True)

In [6]:
# Evaluation subset: the first 500 rows of the 5000-row slice.
valid_dataset = dataset.select(range(500))

In [7]:
# Plain (non-chat) prompts in the "###Question / ###Answer" format.
inputs = [f"###Question:\n{question}\n###Answer:\n" for question in valid_dataset['question']]

In [8]:
# Reference answers for the same first 500 rows that make up valid_dataset.
labels = short_answers_gt[:500]

In [9]:
# Sanity check: prompts and labels should have the same length.
len(inputs), len(labels)

(500, 500)

In [10]:
# Number of CUDA devices reported by torch; used as tensor_parallel_size when building the vLLM engine.
NUM_GPUS = torch.cuda.device_count(); NUM_GPUS

1

In [11]:
# Tokenizer repo id used both for chat templating and as the tokenizer of the fine-tuned vLLM engine.
TOKENIZER = "meta-llama/Meta-Llama-3.1-8B-Instruct"

In [12]:
# Tokenizer instance whose chat template is applied by the chat-input helpers below.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

In [13]:
def convert_to_chat_input(question):
    """
    Format `question` as a chat prompt string.

    The conversation is a fixed math-assistant system message followed by the
    question as the user turn. It is rendered with the notebook-level
    `tokenizer`'s chat template, untokenized and with the generation prompt appended.
    """
    system_turn = {"role": "system", "content": "You are an AI assistant that excels in solving math problems."}
    user_turn = {"role": "user", "content": question}
    return tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

In [14]:
# Chat-templated prompt for every question in the evaluation subset.
chat_inputs = [convert_to_chat_input(question) for question in valid_dataset['question']]

In [15]:
# Sanity check: one chat prompt per evaluation question.
len(chat_inputs)

500

#### FINETUNED 

In [16]:
# Root directory that is joined with MODEL_NAME when the vLLM engine is constructed.
model_dir = "/workspace/models/"

In [17]:
# TODO: Don't cast to float16 for bitblas

In [18]:
# NOTE: 4/2 bit models become very repetitive


# bf16 model # 0.60

# 4bit HQQ (skip_dora_all)
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4bit/step_250/vllm_tinygemm" # 0.578

# 4bit HQQ+DORA
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4bit/step_250/vllm_tinygemm" # 0.593
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4bit/step_1000/vllm_tinygemm" # 0.610

# 4/2bit HQQ+DORA
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4-2bit/step_250/vllm_bitblas" # 0.290
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4-2bit/step_1000/vllm_bitblas" # 0.380

# 4/2bit HQQ+DORA (merged)
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4-2bit/step_250/merged" # 0.242
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4-2bit/step_1000/merged" # 0.350

# 4bit HQQ+DORA+LN
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4bit-ln/step_250/vllm_tinygemm" # 0.562
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4bit-ln/step_1000/vllm_tinygemm" # 0.585

# 4bit HQQ+LN
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-4bit-ln/step_250/vllm_tinygemm" # 0.605
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-4bit-ln/step_1000/vllm_tinygemm" # 0.55

# 4/2bit HQQ+DORA+LN
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4-2bit-ln/step_250/vllm_bitblas" # 0.270
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4-2bit-ln/step_1000/vllm_bitblas" # 0.42

# 4/2bit HQQ+DORA+LN (merged)
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4-2bit-ln/step_250/merged" # 0.270
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-dora-4-2bit-ln/step_1000/merged" # 0.343

# 4(HQQ)bit 2(HQQ+DORA)bit
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-4-dora-2bit/step_250/vllm_bitblas" # 0.257
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-4-dora-2bit/step_1000/vllm_bitblas" # 0.330

# 4(HQQ)bit 2(HQQ+DORA)bit+LN
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-4-dora-2bit-ln/step_250/vllm_bitblas" # 0.265
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-4-dora-2bit-ln/step_1000/vllm_bitblas" # 0.375

# 4(HQQ)bit 2(HQQ)bit+LN
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-4-hqq-2bit-ln/step_250/vllm_bitblas" # 0.02
# MODEL_NAME = "llama-3-1-8b-instruct-hqq-4-hqq-2bit-ln/step_1000/vllm_bitblas" # 0.03


# Model generations become more repetitive with more training steps at temp=0.0.
# Frequency penalty helps with this, but needs more investigation.

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-1e-4-lr_div_factor-10-train_layernorms-true/step_125/vllm_bitblas" # 0.22
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-1e-4-lr_div_factor-10-train_layernorms-true/step_250/vllm_bitblas" # 0.26

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-1e-4-lr_div_factor-3-train_layernorms-true/step_125/vllm_bitblas" # 0.24
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-1e-4-lr_div_factor-3-train_layernorms-true/step_250/vllm_bitblas" # 0.31

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-1e-4-lr_div_factor-1-train_layernorms-true/step_125/vllm_bitblas" # 0.22
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-1e-4-lr_div_factor-1-train_layernorms-true/step_250/vllm_bitblas" # 0.29

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-5e-5-lr_div_factor-10-train_layernorms-true/step_125/vllm_bitblas" # 0.30
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-5e-5-lr_div_factor-10-train_layernorms-true/step_250/vllm_bitblas" # 0.35

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-5e-5-lr_div_factor-3-train_layernorms-true/step_125/vllm_bitblas" # 0.26
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-5e-5-lr_div_factor-3-train_layernorms-true/step_250/vllm_bitblas" # 0.35

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-5e-5-lr_div_factor-1-train_layernorms-true/step_125/vllm_bitblas" # 0.27
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-5e-5-lr_div_factor-1-train_layernorms-true/step_250/vllm_bitblas" # 0.36

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-2e-5-lr_div_factor-10-train_layernorms-true/step_125/vllm_bitblas" # 0.20
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-2e-5-lr_div_factor-10-train_layernorms-true/step_250/vllm_bitblas" # 0.32

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-2e-5-lr_div_factor-3-train_layernorms-true/step_125/vllm_bitblas" # 0.27
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-2e-5-lr_div_factor-3-train_layernorms-true/step_250/vllm_bitblas" # 0.30

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-2e-5-lr_div_factor-1-train_layernorms-true/step_125/vllm_bitblas" # 0.23
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-2e-5-lr_div_factor-1-train_layernorms-true/step_250/vllm_bitblas" # 0.34



# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-1e-4-lr_div_factor-10-train_layernorms-true/step_125/vllm_bitblas" # 0.25
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-1e-4-lr_div_factor-10-train_layernorms-true/step_250/vllm_bitblas" # 0.29

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-1e-4-lr_div_factor-3-train_layernorms-true/step_125/vllm_bitblas" # -
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-1e-4-lr_div_factor-3-train_layernorms-true/step_250/vllm_bitblas" # 0.27

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-1e-4-lr_div_factor-1-train_layernorms-true/step_125/vllm_bitblas" # 0.0
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-1e-4-lr_div_factor-1-train_layernorms-true/step_250/vllm_bitblas" # 0.0

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true/step_125/vllm_bitblas" # 0.29
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true/step_250/vllm_bitblas" # 0.45

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-3-train_layernorms-true/step_125/vllm_bitblas" # 0.27
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-3-train_layernorms-true/step_250/vllm_bitblas" # 0.37

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-1-train_layernorms-true/step_125/vllm_bitblas" # 0.26
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-1-train_layernorms-true/step_250/vllm_bitblas" # 0.27

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-2e-5-lr_div_factor-10-train_layernorms-true/step_125/vllm_bitblas" # 0.28
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-2e-5-lr_div_factor-10-train_layernorms-true/step_250/vllm_bitblas" # 0.32

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-2e-5-lr_div_factor-3-train_layernorms-true/step_125/vllm_bitblas" # 0.27
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-64-base_lr-2e-5-lr_div_factor-3-train_layernorms-true/step_250/vllm_bitblas" # 0.30

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-2e-5-lr_div_factor-1-train_layernorms-true/step_125/vllm_bitblas" # 0.25
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-2e-5-lr_div_factor-1-train_layernorms-true/step_250/vllm_bitblas" # 0.30



# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-1e-4-lr_div_factor-10-train_layernorms-true/step_125/vllm_bitblas" # 0.22
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-1e-4-lr_div_factor-10-train_layernorms-true/step_250/vllm_bitblas" # 0.26

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-1e-4-lr_div_factor-3-train_layernorms-true/step_125/vllm_bitblas" # 0.19
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-1e-4-lr_div_factor-3-train_layernorms-true/step_250/vllm_bitblas" # 0.07

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-1e-4-lr_div_factor-1-train_layernorms-true/step_125/vllm_bitblas" # 0.02
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-1e-4-lr_div_factor-1-train_layernorms-true/step_250/vllm_bitblas" # 0.07

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-5e-5-lr_div_factor-10-train_layernorms-true/step_125/vllm_bitblas" # 0.28
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-5e-5-lr_div_factor-10-train_layernorms-true/step_250/vllm_bitblas" # 0.32

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-5e-5-lr_div_factor-3-train_layernorms-true/step_125/vllm_bitblas" # 0.25
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-5e-5-lr_div_factor-3-train_layernorms-true/step_250/vllm_bitblas" # 0.41

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-5e-5-lr_div_factor-1-train_layernorms-true/step_125/vllm_bitblas" # 0.21
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-5e-5-lr_div_factor-1-train_layernorms-true/step_250/vllm_bitblas" # 0.20

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-2e-5-lr_div_factor-10-train_layernorms-true/step_125/vllm_bitblas" # 0.25
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-2e-5-lr_div_factor-10-train_layernorms-true/step_250/vllm_bitblas" # 0.31

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-2e-5-lr_div_factor-3-train_layernorms-true/step_125/vllm_bitblas" # 0.23
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-2e-5-lr_div_factor-3-train_layernorms-true/step_250/vllm_bitblas" # 0.36

# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-2e-5-lr_div_factor-1-train_layernorms-true/step_125/vllm_bitblas" # 0.26
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-512-base_lr-2e-5-lr_div_factor-1-train_layernorms-true/step_250/vllm_bitblas" # 0.31



# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence/step_125/vllm_bitblas" # 0.41
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence/step_250/vllm_bitblas" # 0.43
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence/step_375/vllm_bitblas" # 0.38
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence/step_500/vllm_bitblas" # 0.46


# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-loftq/step_125/vllm_bitblas" # 0.22
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-loftq/step_250/vllm_bitblas" # 0.32
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-loftq/step_375/vllm_bitblas" # 0.32
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-loftq/step_500/vllm_bitblas" # 0.33


# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-loftq-block-influence/step_125/vllm_bitblas" # 0.28
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-loftq-block-influence/step_250/vllm_bitblas" # 0.39
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-loftq-block-influence/step_375/vllm_bitblas" # 0.41
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-loftq-block-influence/step_500/vllm_bitblas" # 0.414


# 4 (128) bit 2 (32) bit not adj 20%
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_125/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_125/merged" # 0.39
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_250/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_250/merged" # 0.45
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_375/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_375/merged" # 0.41
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_500/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_500/merged" # 0.51
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_625/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_625/merged" # 0.49
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_750/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-20pct/step_750/merged" # 0.57

# # 4 (128) bit 2 (32) bit adj 20%
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-20pct/step_125/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-20pct/step_125/merged" # 0.44
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-20pct/step_250/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-20pct/step_250/merged" # 0.36
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-20pct/step_375/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-20pct/step_375/merged" # 0.43
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-20pct/step_500/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-20pct/step_500/merged" # 0.45

# # 4 (128) bit 2 (32) bit not adj 30%
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-30pct/step_125/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-30pct/step_125/merged" # 0.41
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-30pct/step_250/vllm_bitblas" # 0.
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-no-adj-30pct/step_250/merged" # 0.48

# # 4 (128) bit 2 (32) bit adj 30%
# MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-30pct/step_125/vllm_bitblas" # 0.
MODEL_NAME = "llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-30pct/step_125/merged" # 0.42 on the first 128 chat prompts (exact-match output recorded below)


In [19]:
# Build the vLLM engine for the selected checkpoint under model_dir, using the
# TOKENIZER repo for tokenization. The bitblas quantization and eager-mode
# options are left disabled, as before.
llm = LLM(
    model=os.path.join(model_dir, MODEL_NAME),
    tokenizer=TOKENIZER,
    tensor_parallel_size=NUM_GPUS,
    max_model_len=8192,
    # quantization="bitblas",
    dtype="bfloat16",
    gpu_memory_utilization=0.8,
    # enforce_eager=True,
)

INFO 08-28 15:19:10 llm_engine.py:176] Initializing an LLM engine (v0.5.3.post1) with config: model='/workspace/models/llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-base_lr-5e-5-lr_div_factor-10-train_layernorms-true-block-influence-adj-30pct/step_125/merged', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3.1-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/workspace/models/llama-3-1-8b-instruct-dora-4-2bit-lora_rank-256-b

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]


Loading model.embed_tokens.weight
Loaded model.embed_tokens.weight
Loading model.layers.0.input_layernorm.weight
Loaded model.layers.0.input_layernorm.weight
Loading model.layers.0.mlp.down_proj.weight
Loaded model.layers.0.mlp.down_proj.weight
Loading model.layers.0.mlp.gate_proj.weight
Loaded model.layers.0.mlp.gate_up_proj.weight
Loading model.layers.0.mlp.up_proj.weight
Loaded model.layers.0.mlp.gate_up_proj.weight
Loading model.layers.0.post_attention_layernorm.weight
Loaded model.layers.0.post_attention_layernorm.weight
Loading model.layers.0.self_attn.k_proj.weight
Loaded model.layers.0.self_attn.qkv_proj.weight
Loading model.layers.0.self_attn.o_proj.weight
Loaded model.layers.0.self_attn.o_proj.weight
Loading model.layers.0.self_attn.q_proj.weight
Loaded model.layers.0.self_attn.qkv_proj.weight
Loading model.layers.0.self_attn.v_proj.weight
Loaded model.layers.0.self_attn.qkv_proj.weight
Loading model.layers.1.input_layernorm.weight
Loaded model.layers.1.input_layernorm.weight

In [20]:
# base model
# outputs = llm.generate(inputs, SamplingParams(temperature=0.0, max_tokens=1024))

In [21]:
# chat model
# Only the first 128 chat prompts are generated here (temperature=0.0, stop at <|eot_id|>).
outputs = llm.generate(chat_inputs[:128], SamplingParams(temperature=0.0, max_tokens=1024, stop=["<|eot_id|>"]))

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts: 100%|██████████| 128/128 [00:19<00:00,  6.53it/s, est. speed input: 706.05 toks/s, output: 1617.90 toks/s] 


In [22]:
# Predicted short answer = last number/ratio in the first completion of each request.
short_answers_pred = [extract_last_number_or_ratio(o.outputs[0].text) for o in outputs]

In [23]:
# `labels` holds 500 entries while only 128 predictions exist; exact_match_score pairs them
# with zip(), so just the first 128 labels are compared and the denominator is 128.
exact_match_score(short_answers_pred, labels)

0.421875

In [None]:
# Qualitative check on a coding prompt, with frequency and presence penalties enabled.
chat_input = [convert_to_chat_input("Write python code for image classification using HF transformers library. Start with `from transformers import ...`")]
outputs = llm.generate(chat_input, SamplingParams(temperature=0.0, max_tokens=1024, stop=["<|eot_id|>"],
                                                  frequency_penalty=0.3, presence_penalty=0.5))
print(outputs[0].outputs[0].text)

In [39]:
# Qualitative check on a factual list prompt (frequency penalty only, at most 256 new tokens).
chat_input = [convert_to_chat_input("List top 10 scientist from 15th century as numbered list with their dob and death. Inclde da vinci.")]
outputs = llm.generate(chat_input, SamplingParams(temperature=0.0, max_tokens=256, stop=["<|eot_id|>"],
                                                  frequency_penalty=0.5, presence_penalty=0.))
print(outputs[0].outputs[0].text)

Processed prompts: 100%|██████████| 1/1 [00:01<00:00,  1.82s/it, est. speed input: 41.34 toks/s, output: 77.16 toks/s]

1. Leonardo da Vinci (1452-1516)
2. Alhunghar Nasimi (1450-1515)
3. Giovanni Battista Belli (1474-1516)
4. Giovanni Battista Cusano (1450-1508)
5. Pietro da Cortona (1596-1635)
6. Michelangelo Buonarroti (1478-1564)
7. Raphael Sanzio (1503-1522)
8. Andrea del Verrocchio (1470-1487)
9. Sandi Baccalini (1460s - 1480s) 
10. Leonardo of Fiesole





#### N-SHOT

In [None]:
# Baseline (not fine-tuned) model for the n-shot comparisons; this rebinds MODEL_NAME and llm.
# NOTE(review): this is the Llama-3 instruct repo, while TOKENIZER (used to build chat_inputs)
# points at Llama-3.1 -- confirm the mismatch is intended.
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
llm = LLM(model=MODEL_NAME, tensor_parallel_size=NUM_GPUS, dtype="bfloat16")

In [17]:
# zero-shot
inputs = [f"###Question:\n{question}\n###Answer:\n" for question in valid_dataset['question']]
outputs = llm.generate(inputs, SamplingParams(temperature=0.0, max_tokens=1024))
short_answers_pred = [extract_last_number_or_ratio(o.outputs[0].text) for o in outputs]
exact_match_score(short_answers_pred, labels)

Processed prompts: 100%|██████████| 500/500 [01:20<00:00,  6.18it/s]


0.228

In [17]:
# zero-shot (instruct)
outputs = llm.generate(chat_inputs, SamplingParams(temperature=0.0, max_tokens=1024))
short_answers_pred = [extract_last_number_or_ratio(o.outputs[0].text) for o in outputs]
exact_match_score(short_answers_pred, labels)

Processed prompts: 100%|██████████| 500/500 [01:07<00:00,  7.41it/s]


0.454

In [None]:
# 5-shot
few_shot_examples = [f"###Question:\n{ex['question']}\n###Answer:\n{ex['answer']}<stop>" for ex in 
                     dataset.select(range(len(dataset)-5,len(dataset)))]
few_shot_prompt = "\n\n".join(few_shot_examples)
inputs = [few_shot_prompt + "\n\n" + f"###Question:\n{question}\n###Answer:\n" for question in valid_dataset['question']]
outputs = llm.generate(inputs, SamplingParams(temperature=0.0, 
                                              stop_token_ids=[tokenizer.eos_token_id], 
                                              stop=["<stop>"], 
                                              max_tokens=1024))
short_answers_pred = [extract_last_number_or_ratio(o.outputs[0].text) for o in outputs]
exact_match_score(short_answers_pred, labels)

Processed prompts:   9%|▉         | 47/500 [01:06<02:45,  2.73it/s] 

In [30]:
# Rebuilds the same five demonstrations as the 5-shot cell (last 5 rows of `dataset`).
few_shot_examples = [f"###Question:\n{ex['question']}\n###Answer:\n{ex['answer']}<stop>" for ex in 
                     dataset.select(range(len(dataset)-5,len(dataset)))]

In [32]:
# Demonstrations joined by blank lines; embedded into the system message by fewshot_chat_input.
few_shot_prompt = "\n\n".join(few_shot_examples)

In [37]:
def fewshot_chat_input(question, answer=None):
    """
    Format `question` as a chat prompt whose system message carries the few-shot examples.

    The system message is the math-assistant instruction followed by the
    notebook-level `few_shot_prompt`; the user turn is `question`. The prompt is
    rendered with the notebook-level `tokenizer`'s chat template, untokenized and
    with the generation prompt appended. `answer` is accepted but not used.
    """
    system_turn = {
        "role": "system",
        "content": f"You are an AI assistant that excels in solving math problems. Here are few examples of math problems:\n\n{few_shot_prompt}",
    }
    user_turn = {"role": "user", "content": question}
    return tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

In [38]:
# Few-shot chat prompt for every question in the evaluation subset.
fewshot_chat_inputs = [fewshot_chat_input(question)  for question in valid_dataset['question']]

In [40]:
# 5-shot (instruct): generate from the few-shot chat prompts and score all 500 rows.
outputs = llm.generate(fewshot_chat_inputs, SamplingParams(temperature=0.0, 
                                              stop_token_ids=[tokenizer.eos_token_id], 
                                              stop=["<|eot_id|>"],
                                              max_tokens=1024))
short_answers_pred = [extract_last_number_or_ratio(o.outputs[0].text) for o in outputs]
exact_match_score(short_answers_pred, labels)

Processed prompts: 100%|██████████| 500/500 [02:32<00:00,  3.27it/s]


0.452