In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [3]:
import awq

In [4]:
import tqdm
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from functools import partial
import gc

In [5]:
def evaluate(model, tokenizer):
    testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt")

    testenc = testenc.input_ids.to(model.device)
    nsamples = 40
    model = model.eval()

    nlls = []
    for i in tqdm.tqdm(range(nsamples), desc="evaluating..."):
        batch = testenc[:, (i * 2048) : ((i + 1) * 2048)].to(model.device)
        with torch.no_grad():
            lm_logits = model(batch).logits
        shift_logits = lm_logits[:, :-1, :].contiguous().float()
        shift_labels = testenc[:, (i * 2048) : ((i + 1) * 2048)][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        neg_log_likelihood = loss.float() * 2048
        nlls.append(neg_log_likelihood)

    return torch.exp(torch.stack(nlls).sum() / (nsamples * 2048))

In [6]:
def get_model_size(model: nn.Module, data_width=16, group_size=-1):

    if group_size != -1:
        data_width += (16 + 4) / group_size

    num_elements = 0
    for param in model.parameters():
        num_elements += param.numel()
    return num_elements * data_width


Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

In [7]:
model_path = "facebook/opt-1.3b"

# model_path = "facebook/opt-13b"

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)



In [8]:
original_model_n_bits = 32
torch_dtype = torch.float16 if original_model_n_bits == 16 else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch_dtype, device_map="auto"
)



In [9]:
# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=original_model_n_bits, group_size=128)

### Print the results
print(f"\nmodel perplexity: {model_perplexity:.2f}")
print(f"model size: {model_size/MiB:.2f} MiB")

evaluating...: 100%|██████████| 40/40 [00:20<00:00,  1.93it/s]



model perplexity: 14.47
model size: 5043.73 MiB


In [48]:
from awq.quantize.pre_quant_lester import run_awq, apply_awq
from typing import Literal

ActQuantType = Literal["per_token", "per_tensor", "none"]

q_config = {
    "zero_point": True,  # by default True
    "q_group_size": 256,  # whether to use group quantization
    "w_n_bits": 4,
    "a_n_bits": 4,
    "act_quant": "per_tensor",
}

model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch_dtype, device_map="auto"
)

awq_results = run_awq(
    model,
    tokenizer,
    w_bit=q_config["w_n_bits"],
    q_config=q_config,
    n_samples=128,
    seqlen=512,
)

Repo card metadata block was not found. Setting CardData to empty.


 * Split into 60 blocks


Running AWQ...:   0%|          | 0/24 [00:00<?, ?it/s]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 1.0621346235275269
ratio: 0.05, loss: 1.0885839462280273
ratio: 0.1, loss: 0.985962986946106
ratio: 0.15, loss: 0.9112263917922974
ratio: 0.2, loss: 0.9200567603111267
ratio: 0.25, loss: 0.9309477806091309
ratio: 0.3, loss: 0.9278945326805115
ratio: 0.35, loss: 0.9316182732582092
ratio: 0.4, loss: 0.9297643303871155
ratio: 0.45, loss: 0.9304285049438477
ratio: 0.5, loss: 0.9303783774375916
ratio: 0.55, loss: 0.9310886859893799
ratio: 0.6, loss: 0.9300938248634338
ratio: 0.65, loss: 0.930308997631073
ratio: 0.7, loss: 0.9299536347389221
ratio: 0.75, loss: 0.9301728010177612
ratio: 0.8, loss: 0.9302728772163391
ratio: 0.85, loss: 0.9302891492843628
ratio: 0.9, loss: 0.9302465319633484
ratio: 0.95, loss: 0.9301099181175232
ratio: 0.0, loss: 0.01421

Running AWQ...:   4%|▍         | 1/24 [00:08<03:18,  8.62s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 1.64265775680542
ratio: 0.05, loss: 1.535681962966919
ratio: 0.1, loss: 1.5437500476837158
ratio: 0.15, loss: 1.539994239807129
ratio: 0.2, loss: 1.5368536710739136
ratio: 0.25, loss: 1.5346659421920776
ratio: 0.3, loss: 1.5365595817565918
ratio: 0.35, loss: 1.535537600517273
ratio: 0.4, loss: 1.5344001054763794
ratio: 0.45, loss: 1.5348025560379028
ratio: 0.5, loss: 1.5346769094467163
ratio: 0.55, loss: 1.5339692831039429
ratio: 0.6, loss: 1.5350390672683716
ratio: 0.65, loss: 1.534542202949524
ratio: 0.7, loss: 1.534589171409607
ratio: 0.75, loss: 1.5331711769104004
ratio: 0.8, loss: 1.5354747772216797
ratio: 0.85, loss: 1.537583827972412
ratio: 0.9, loss: 1.5354299545288086
ratio: 0.95, loss: 1.5298359394073486
ratio: 0.0, loss: 0.00264464202

Running AWQ...:   8%|▊         | 2/24 [00:15<02:50,  7.77s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 1.825465202331543
ratio: 0.05, loss: 1.791100025177002
ratio: 0.1, loss: 1.7546273469924927
ratio: 0.15, loss: 1.7486554384231567
ratio: 0.2, loss: 1.7426435947418213
ratio: 0.25, loss: 1.7372469902038574
ratio: 0.3, loss: 1.73504638671875
ratio: 0.35, loss: 1.7342441082000732
ratio: 0.4, loss: 1.7327215671539307
ratio: 0.45, loss: 1.7342054843902588
ratio: 0.5, loss: 1.7330964803695679
ratio: 0.55, loss: 1.7335439920425415
ratio: 0.6, loss: 1.7333019971847534
ratio: 0.65, loss: 1.732922911643982
ratio: 0.7, loss: 1.7323499917984009
ratio: 0.75, loss: 1.7320928573608398
ratio: 0.8, loss: 1.7325963973999023
ratio: 0.85, loss: 1.7343332767486572
ratio: 0.9, loss: 1.7355806827545166
ratio: 0.95, loss: 1.7347527742385864
ratio: 0.0, loss: 0.00140699

Running AWQ...:  12%|█▎        | 3/24 [00:23<02:38,  7.57s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 1.9574739933013916
ratio: 0.05, loss: 1.9716770648956299
ratio: 0.1, loss: 1.9790089130401611
ratio: 0.15, loss: 1.9727791547775269
ratio: 0.2, loss: 1.9674952030181885
ratio: 0.25, loss: 1.9639595746994019
ratio: 0.3, loss: 1.9603599309921265
ratio: 0.35, loss: 1.9601446390151978
ratio: 0.4, loss: 1.9605523347854614
ratio: 0.45, loss: 1.9605144262313843
ratio: 0.5, loss: 1.9602103233337402
ratio: 0.55, loss: 1.9609137773513794
ratio: 0.6, loss: 1.960694432258606
ratio: 0.65, loss: 1.9606273174285889
ratio: 0.7, loss: 1.9609429836273193
ratio: 0.75, loss: 1.9601742029190063
ratio: 0.8, loss: 1.9603244066238403
ratio: 0.85, loss: 1.9601202011108398
ratio: 0.9, loss: 1.960702657699585
ratio: 0.95, loss: 1.9604928493499756
ratio: 0.0, loss: 0.00082

Running AWQ...:  17%|█▋        | 4/24 [00:30<02:29,  7.46s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 1.985856533050537
ratio: 0.05, loss: 1.9633346796035767
ratio: 0.1, loss: 1.9441944360733032
ratio: 0.15, loss: 1.9383623600006104
ratio: 0.2, loss: 1.927996277809143
ratio: 0.25, loss: 1.926278829574585
ratio: 0.3, loss: 1.9228038787841797
ratio: 0.35, loss: 1.9235639572143555
ratio: 0.4, loss: 1.9224259853363037
ratio: 0.45, loss: 1.9227455854415894
ratio: 0.5, loss: 1.9222346544265747
ratio: 0.55, loss: 1.9228142499923706
ratio: 0.6, loss: 1.9221203327178955
ratio: 0.65, loss: 1.922487735748291
ratio: 0.7, loss: 1.9228153228759766
ratio: 0.75, loss: 1.9225214719772339
ratio: 0.8, loss: 1.9226495027542114
ratio: 0.85, loss: 1.9222967624664307
ratio: 0.9, loss: 1.921908974647522
ratio: 0.95, loss: 1.9216525554656982
ratio: 0.0, loss: 0.00083988

Running AWQ...:  21%|██        | 5/24 [00:37<02:20,  7.41s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.2868988513946533
ratio: 0.05, loss: 2.2655770778656006
ratio: 0.1, loss: 2.2535817623138428
ratio: 0.15, loss: 2.2408225536346436
ratio: 0.2, loss: 2.2341227531433105
ratio: 0.25, loss: 2.225645065307617
ratio: 0.3, loss: 2.2204086780548096
ratio: 0.35, loss: 2.21950364112854
ratio: 0.4, loss: 2.2178099155426025
ratio: 0.45, loss: 2.2204599380493164
ratio: 0.5, loss: 2.219845771789551
ratio: 0.55, loss: 2.2196133136749268
ratio: 0.6, loss: 2.2196242809295654
ratio: 0.65, loss: 2.219454050064087
ratio: 0.7, loss: 2.21968674659729
ratio: 0.75, loss: 2.2198264598846436
ratio: 0.8, loss: 2.220099449157715
ratio: 0.85, loss: 2.2196080684661865
ratio: 0.9, loss: 2.219183921813965
ratio: 0.95, loss: 2.2188467979431152
ratio: 0.0, loss: 0.001232846290

Running AWQ...:  25%|██▌       | 6/24 [00:45<02:14,  7.48s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.3185579776763916
ratio: 0.05, loss: 2.2887182235717773
ratio: 0.1, loss: 2.2770493030548096
ratio: 0.15, loss: 2.2701940536499023
ratio: 0.2, loss: 2.2588024139404297
ratio: 0.25, loss: 2.255110263824463
ratio: 0.3, loss: 2.2513561248779297
ratio: 0.35, loss: 2.2486300468444824
ratio: 0.4, loss: 2.2479629516601562
ratio: 0.45, loss: 2.2466881275177
ratio: 0.5, loss: 2.2477269172668457
ratio: 0.55, loss: 2.2474398612976074
ratio: 0.6, loss: 2.247495174407959
ratio: 0.65, loss: 2.24777889251709
ratio: 0.7, loss: 2.2478461265563965
ratio: 0.75, loss: 2.2479162216186523
ratio: 0.8, loss: 2.247710943222046
ratio: 0.85, loss: 2.247565984725952
ratio: 0.9, loss: 2.247483253479004
ratio: 0.95, loss: 2.2470617294311523
ratio: 0.0, loss: 0.0016170025337

Running AWQ...:  29%|██▉       | 7/24 [00:52<02:04,  7.30s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.5163135528564453
ratio: 0.05, loss: 2.496886968612671
ratio: 0.1, loss: 2.4939968585968018
ratio: 0.15, loss: 2.4915571212768555
ratio: 0.2, loss: 2.482785701751709
ratio: 0.25, loss: 2.4780797958374023
ratio: 0.3, loss: 2.47410249710083
ratio: 0.35, loss: 2.4721992015838623
ratio: 0.4, loss: 2.471640110015869
ratio: 0.45, loss: 2.471910238265991
ratio: 0.5, loss: 2.472386360168457
ratio: 0.55, loss: 2.4709279537200928
ratio: 0.6, loss: 2.4712703227996826
ratio: 0.65, loss: 2.4713664054870605
ratio: 0.7, loss: 2.4718775749206543
ratio: 0.75, loss: 2.471503257751465
ratio: 0.8, loss: 2.4717488288879395
ratio: 0.85, loss: 2.4716410636901855
ratio: 0.9, loss: 2.471071243286133
ratio: 0.95, loss: 2.4710094928741455
ratio: 0.0, loss: 0.001598772360

Running AWQ...:  33%|███▎      | 8/24 [00:59<01:58,  7.41s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.654538154602051
ratio: 0.05, loss: 2.6555469036102295
ratio: 0.1, loss: 2.6598780155181885
ratio: 0.15, loss: 2.6628592014312744
ratio: 0.2, loss: 2.656491994857788
ratio: 0.25, loss: 2.652235269546509
ratio: 0.3, loss: 2.6524569988250732
ratio: 0.35, loss: 2.650285482406616
ratio: 0.4, loss: 2.648380756378174
ratio: 0.45, loss: 2.6479406356811523
ratio: 0.5, loss: 2.6482651233673096
ratio: 0.55, loss: 2.647615432739258
ratio: 0.6, loss: 2.647894859313965
ratio: 0.65, loss: 2.6476621627807617
ratio: 0.7, loss: 2.648118257522583
ratio: 0.75, loss: 2.647925615310669
ratio: 0.8, loss: 2.6482439041137695
ratio: 0.85, loss: 2.6481919288635254
ratio: 0.9, loss: 2.6482036113739014
ratio: 0.95, loss: 2.6477084159851074
ratio: 0.0, loss: 0.002447569044

Running AWQ...:  38%|███▊      | 9/24 [01:08<01:54,  7.64s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.7060134410858154
ratio: 0.05, loss: 2.7114474773406982
ratio: 0.1, loss: 2.7218079566955566
ratio: 0.15, loss: 2.716007947921753
ratio: 0.2, loss: 2.7154579162597656
ratio: 0.25, loss: 2.7086870670318604
ratio: 0.3, loss: 2.709371566772461
ratio: 0.35, loss: 2.709717035293579
ratio: 0.4, loss: 2.709404468536377
ratio: 0.45, loss: 2.709153413772583
ratio: 0.5, loss: 2.7079591751098633
ratio: 0.55, loss: 2.7091243267059326
ratio: 0.6, loss: 2.708437204360962
ratio: 0.65, loss: 2.7078776359558105
ratio: 0.7, loss: 2.708095073699951
ratio: 0.75, loss: 2.708505153656006
ratio: 0.8, loss: 2.7083613872528076
ratio: 0.85, loss: 2.708491086959839
ratio: 0.9, loss: 2.7086994647979736
ratio: 0.95, loss: 2.708580732345581
ratio: 0.0, loss: 0.0034953111317

Running AWQ...:  42%|████▏     | 10/24 [01:15<01:45,  7.55s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.7317423820495605
ratio: 0.05, loss: 2.7347116470336914
ratio: 0.1, loss: 2.738438606262207
ratio: 0.15, loss: 2.730745553970337
ratio: 0.2, loss: 2.7315075397491455
ratio: 0.25, loss: 2.7278969287872314
ratio: 0.3, loss: 2.7290046215057373
ratio: 0.35, loss: 2.72650408744812
ratio: 0.4, loss: 2.725104331970215
ratio: 0.45, loss: 2.725640296936035
ratio: 0.5, loss: 2.725955009460449
ratio: 0.55, loss: 2.726181983947754
ratio: 0.6, loss: 2.72602915763855
ratio: 0.65, loss: 2.7259411811828613
ratio: 0.7, loss: 2.7254600524902344
ratio: 0.75, loss: 2.725958824157715
ratio: 0.8, loss: 2.726091146469116
ratio: 0.85, loss: 2.7257375717163086
ratio: 0.9, loss: 2.726215362548828
ratio: 0.95, loss: 2.7258546352386475
ratio: 0.0, loss: 0.0048964116722345

Running AWQ...:  46%|████▌     | 11/24 [01:22<01:35,  7.37s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.783752679824829
ratio: 0.05, loss: 2.779613494873047
ratio: 0.1, loss: 2.7823479175567627
ratio: 0.15, loss: 2.781463861465454
ratio: 0.2, loss: 2.7784857749938965
ratio: 0.25, loss: 2.782078742980957
ratio: 0.3, loss: 2.778337240219116
ratio: 0.35, loss: 2.7787346839904785
ratio: 0.4, loss: 2.7799720764160156
ratio: 0.45, loss: 2.778958797454834
ratio: 0.5, loss: 2.779310703277588
ratio: 0.55, loss: 2.779071092605591
ratio: 0.6, loss: 2.7790849208831787
ratio: 0.65, loss: 2.778714179992676
ratio: 0.7, loss: 2.7793240547180176
ratio: 0.75, loss: 2.779846429824829
ratio: 0.8, loss: 2.7797625064849854
ratio: 0.85, loss: 2.7791481018066406
ratio: 0.9, loss: 2.779136896133423
ratio: 0.95, loss: 2.7797396183013916
ratio: 0.0, loss: 0.00538439303636

Running AWQ...:  50%|█████     | 12/24 [01:29<01:28,  7.35s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.537404775619507
ratio: 0.05, loss: 2.5349228382110596
ratio: 0.1, loss: 2.5293986797332764
ratio: 0.15, loss: 2.530937910079956
ratio: 0.2, loss: 2.535053014755249
ratio: 0.25, loss: 2.534853219985962
ratio: 0.3, loss: 2.5344443321228027
ratio: 0.35, loss: 2.5345876216888428
ratio: 0.4, loss: 2.5347743034362793
ratio: 0.45, loss: 2.534226655960083
ratio: 0.5, loss: 2.534853935241699
ratio: 0.55, loss: 2.535165309906006
ratio: 0.6, loss: 2.534177780151367
ratio: 0.65, loss: 2.5353426933288574
ratio: 0.7, loss: 2.5345981121063232
ratio: 0.75, loss: 2.5344793796539307
ratio: 0.8, loss: 2.5344018936157227
ratio: 0.85, loss: 2.5344228744506836
ratio: 0.9, loss: 2.534838914871216
ratio: 0.95, loss: 2.534512519836426
ratio: 0.0, loss: 0.0069630267098

Running AWQ...:  54%|█████▍    | 13/24 [01:36<01:19,  7.25s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.687138795852661
ratio: 0.05, loss: 2.688307762145996
ratio: 0.1, loss: 2.6940786838531494
ratio: 0.15, loss: 2.7003378868103027
ratio: 0.2, loss: 2.703895092010498
ratio: 0.25, loss: 2.707336664199829
ratio: 0.3, loss: 2.706421375274658
ratio: 0.35, loss: 2.7074973583221436
ratio: 0.4, loss: 2.709094285964966
ratio: 0.45, loss: 2.7076683044433594
ratio: 0.5, loss: 2.7082419395446777
ratio: 0.55, loss: 2.7080767154693604
ratio: 0.6, loss: 2.706428050994873
ratio: 0.65, loss: 2.707422971725464
ratio: 0.7, loss: 2.7080237865448
ratio: 0.75, loss: 2.707259178161621
ratio: 0.8, loss: 2.7066714763641357
ratio: 0.85, loss: 2.7066707611083984
ratio: 0.9, loss: 2.7074294090270996
ratio: 0.95, loss: 2.7077457904815674
ratio: 0.0, loss: 0.008159367367625

Running AWQ...:  58%|█████▊    | 14/24 [01:44<01:13,  7.34s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 2.8968300819396973
ratio: 0.05, loss: 2.902735471725464
ratio: 0.1, loss: 2.9070491790771484
ratio: 0.15, loss: 2.9120380878448486
ratio: 0.2, loss: 2.9160549640655518
ratio: 0.25, loss: 2.9147000312805176
ratio: 0.3, loss: 2.918846845626831
ratio: 0.35, loss: 2.9178683757781982
ratio: 0.4, loss: 2.9187827110290527
ratio: 0.45, loss: 2.919306993484497
ratio: 0.5, loss: 2.918199300765991
ratio: 0.55, loss: 2.9181621074676514
ratio: 0.6, loss: 2.918179750442505
ratio: 0.65, loss: 2.917536497116089
ratio: 0.7, loss: 2.917544364929199
ratio: 0.75, loss: 2.917771577835083
ratio: 0.8, loss: 2.917417049407959
ratio: 0.85, loss: 2.9181201457977295
ratio: 0.9, loss: 2.917375087738037
ratio: 0.95, loss: 2.9175758361816406
ratio: 0.0, loss: 0.0104693751782

Running AWQ...:  62%|██████▎   | 15/24 [01:52<01:08,  7.65s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 3.0677976608276367
ratio: 0.05, loss: 3.0739142894744873
ratio: 0.1, loss: 3.080613613128662
ratio: 0.15, loss: 3.0846803188323975
ratio: 0.2, loss: 3.0874507427215576
ratio: 0.25, loss: 3.0883989334106445
ratio: 0.3, loss: 3.0882468223571777
ratio: 0.35, loss: 3.089376926422119
ratio: 0.4, loss: 3.0869388580322266
ratio: 0.45, loss: 3.0883781909942627
ratio: 0.5, loss: 3.087677240371704
ratio: 0.55, loss: 3.0878541469573975
ratio: 0.6, loss: 3.087907075881958
ratio: 0.65, loss: 3.0876190662384033
ratio: 0.7, loss: 3.087521553039551
ratio: 0.75, loss: 3.0876309871673584
ratio: 0.8, loss: 3.0883188247680664
ratio: 0.85, loss: 3.087419271469116
ratio: 0.9, loss: 3.0877151489257812
ratio: 0.95, loss: 3.087599039077759
ratio: 0.0, loss: 0.0170189607

Running AWQ...:  67%|██████▋   | 16/24 [01:59<00:59,  7.38s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 3.606783628463745
ratio: 0.05, loss: 3.6136813163757324
ratio: 0.1, loss: 3.618504285812378
ratio: 0.15, loss: 3.620058536529541
ratio: 0.2, loss: 3.6217942237854004
ratio: 0.25, loss: 3.624103546142578
ratio: 0.3, loss: 3.622575521469116
ratio: 0.35, loss: 3.623767852783203
ratio: 0.4, loss: 3.625082015991211
ratio: 0.45, loss: 3.6245815753936768
ratio: 0.5, loss: 3.6247239112854004
ratio: 0.55, loss: 3.6239705085754395
ratio: 0.6, loss: 3.6247060298919678
ratio: 0.65, loss: 3.6250128746032715
ratio: 0.7, loss: 3.6245810985565186
ratio: 0.75, loss: 3.624676465988159
ratio: 0.8, loss: 3.6246776580810547
ratio: 0.85, loss: 3.6248395442962646
ratio: 0.9, loss: 3.624539375305176
ratio: 0.95, loss: 3.6250083446502686
ratio: 0.0, loss: 0.024313105270

Running AWQ...:  71%|███████   | 17/24 [02:06<00:52,  7.43s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 4.098090171813965
ratio: 0.05, loss: 4.102348327636719
ratio: 0.1, loss: 4.105107307434082
ratio: 0.15, loss: 4.109041690826416
ratio: 0.2, loss: 4.113092422485352
ratio: 0.25, loss: 4.113900661468506
ratio: 0.3, loss: 4.113001346588135
ratio: 0.35, loss: 4.112967014312744
ratio: 0.4, loss: 4.114288330078125
ratio: 0.45, loss: 4.113400936126709
ratio: 0.5, loss: 4.113714218139648
ratio: 0.55, loss: 4.1133880615234375
ratio: 0.6, loss: 4.114016532897949
ratio: 0.65, loss: 4.114101886749268
ratio: 0.7, loss: 4.113901138305664
ratio: 0.75, loss: 4.1137566566467285
ratio: 0.8, loss: 4.113761901855469
ratio: 0.85, loss: 4.113729000091553
ratio: 0.9, loss: 4.113042831420898
ratio: 0.95, loss: 4.11354923248291
ratio: 0.0, loss: 0.03893859311938286
rati

Running AWQ...:  75%|███████▌  | 18/24 [02:14<00:44,  7.44s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 4.507205486297607
ratio: 0.05, loss: 4.5077691078186035
ratio: 0.1, loss: 4.5126633644104
ratio: 0.15, loss: 4.51412296295166
ratio: 0.2, loss: 4.514819145202637
ratio: 0.25, loss: 4.515303611755371
ratio: 0.3, loss: 4.516749382019043
ratio: 0.35, loss: 4.516157150268555
ratio: 0.4, loss: 4.517371654510498
ratio: 0.45, loss: 4.516531467437744
ratio: 0.5, loss: 4.517207145690918
ratio: 0.55, loss: 4.516160011291504
ratio: 0.6, loss: 4.51605224609375
ratio: 0.65, loss: 4.5159430503845215
ratio: 0.7, loss: 4.515417575836182
ratio: 0.75, loss: 4.51545524597168
ratio: 0.8, loss: 4.516058444976807
ratio: 0.85, loss: 4.5158209800720215
ratio: 0.9, loss: 4.515748500823975
ratio: 0.95, loss: 4.5158491134643555
ratio: 0.0, loss: 0.043424397706985474
ratio

Running AWQ...:  79%|███████▉  | 19/24 [02:22<00:38,  7.64s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 5.008797645568848
ratio: 0.05, loss: 5.009836196899414
ratio: 0.1, loss: 5.012081146240234
ratio: 0.15, loss: 5.012326240539551
ratio: 0.2, loss: 5.013529300689697
ratio: 0.25, loss: 5.014819622039795
ratio: 0.3, loss: 5.013433933258057
ratio: 0.35, loss: 5.013425827026367
ratio: 0.4, loss: 5.01279878616333
ratio: 0.45, loss: 5.013525485992432
ratio: 0.5, loss: 5.013399600982666
ratio: 0.55, loss: 5.013429641723633
ratio: 0.6, loss: 5.012959003448486
ratio: 0.65, loss: 5.012485980987549
ratio: 0.7, loss: 5.012461185455322
ratio: 0.75, loss: 5.012942790985107
ratio: 0.8, loss: 5.012975215911865
ratio: 0.85, loss: 5.013162136077881
ratio: 0.9, loss: 5.013697624206543
ratio: 0.95, loss: 5.013622283935547
ratio: 0.0, loss: 0.07589181512594223
ratio:

Running AWQ...:  83%|████████▎ | 20/24 [02:29<00:29,  7.48s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 5.2321906089782715
ratio: 0.05, loss: 5.233692169189453
ratio: 0.1, loss: 5.235561370849609
ratio: 0.15, loss: 5.236658573150635
ratio: 0.2, loss: 5.23853063583374
ratio: 0.25, loss: 5.237856864929199
ratio: 0.3, loss: 5.238380432128906
ratio: 0.35, loss: 5.238424301147461
ratio: 0.4, loss: 5.237058639526367
ratio: 0.45, loss: 5.236603736877441
ratio: 0.5, loss: 5.236937522888184
ratio: 0.55, loss: 5.236280918121338
ratio: 0.6, loss: 5.236105918884277
ratio: 0.65, loss: 5.236388206481934
ratio: 0.7, loss: 5.236429214477539
ratio: 0.75, loss: 5.236907958984375
ratio: 0.8, loss: 5.236187934875488
ratio: 0.85, loss: 5.236227512359619
ratio: 0.9, loss: 5.236375331878662
ratio: 0.95, loss: 5.236305236816406
ratio: 0.0, loss: 0.13064898550510406
ratio

Running AWQ...:  88%|████████▊ | 21/24 [02:36<00:22,  7.41s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 5.1417460441589355
ratio: 0.05, loss: 5.141580104827881
ratio: 0.1, loss: 5.142418384552002
ratio: 0.15, loss: 5.142065048217773
ratio: 0.2, loss: 5.141053199768066
ratio: 0.25, loss: 5.140204429626465
ratio: 0.3, loss: 5.140500068664551
ratio: 0.35, loss: 5.141408443450928
ratio: 0.4, loss: 5.140435218811035
ratio: 0.45, loss: 5.139959812164307
ratio: 0.5, loss: 5.139729976654053
ratio: 0.55, loss: 5.139832496643066
ratio: 0.6, loss: 5.1397809982299805
ratio: 0.65, loss: 5.13938570022583
ratio: 0.7, loss: 5.138613224029541
ratio: 0.75, loss: 5.139151096343994
ratio: 0.8, loss: 5.139272689819336
ratio: 0.85, loss: 5.139551162719727
ratio: 0.9, loss: 5.140227794647217
ratio: 0.95, loss: 5.139985084533691
ratio: 0.0, loss: 0.15806381404399872
rati

Running AWQ...:  92%|█████████▏| 22/24 [02:44<00:14,  7.43s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 4.91106653213501
ratio: 0.05, loss: 4.911135673522949
ratio: 0.1, loss: 4.911931037902832
ratio: 0.15, loss: 4.911045551300049
ratio: 0.2, loss: 4.911061763763428
ratio: 0.25, loss: 4.911036014556885
ratio: 0.3, loss: 4.909476280212402
ratio: 0.35, loss: 4.90964937210083
ratio: 0.4, loss: 4.909162521362305
ratio: 0.45, loss: 4.9094038009643555
ratio: 0.5, loss: 4.90944766998291
ratio: 0.55, loss: 4.909997940063477
ratio: 0.6, loss: 4.9101948738098145
ratio: 0.65, loss: 4.909687042236328
ratio: 0.7, loss: 4.909558296203613
ratio: 0.75, loss: 4.90934944152832
ratio: 0.8, loss: 4.908905982971191
ratio: 0.85, loss: 4.90927267074585
ratio: 0.9, loss: 4.90928316116333
ratio: 0.95, loss: 4.909463882446289
ratio: 0.0, loss: 0.2129235714673996
ratio: 0.0

Running AWQ...:  96%|█████████▌| 23/24 [02:53<00:07,  7.81s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 4.252490043640137
ratio: 0.05, loss: 4.252146244049072
ratio: 0.1, loss: 4.2516350746154785
ratio: 0.15, loss: 4.251548767089844
ratio: 0.2, loss: 4.252331256866455
ratio: 0.25, loss: 4.2508225440979
ratio: 0.3, loss: 4.251089096069336
ratio: 0.35, loss: 4.2515482902526855
ratio: 0.4, loss: 4.2508134841918945
ratio: 0.45, loss: 4.251100063323975
ratio: 0.5, loss: 4.250777721405029
ratio: 0.55, loss: 4.25013542175293
ratio: 0.6, loss: 4.250169277191162
ratio: 0.65, loss: 4.250665187835693
ratio: 0.7, loss: 4.250792503356934
ratio: 0.75, loss: 4.25083065032959
ratio: 0.8, loss: 4.250692367553711
ratio: 0.85, loss: 4.2504777908325195
ratio: 0.9, loss: 4.250931739807129
ratio: 0.95, loss: 4.251259803771973
ratio: 0.0, loss: 0.17429298162460327
ratio

Running AWQ...: 100%|██████████| 24/24 [03:00<00:00,  7.52s/it]


In [49]:
dump_awq = "awq_results.pt"
torch.save(awq_results, dump_awq)
print("AWQ results saved at", dump_awq)
#

AWQ results saved at awq_results.pt


In [50]:
### load awq

load_awq = "awq_results.pt"
awq_results = torch.load(load_awq, map_location="cpu")

In [51]:
q_config

{'zero_point': True,
 'q_group_size': 256,
 'w_n_bits': 4,
 'a_n_bits': 4,
 'act_quant': 'per_tensor'}

In [52]:

from transformers.models.opt.modeling_opt import OPTDecoderLayer
from awq.quantize.fake_quant_lester import quantize_opt_model

model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch_dtype, device_map="auto"
)

# apply the AWQ results
apply_awq(model, awq_results)

model = quantize_opt_model(
    model,
    w_n_bits=q_config["w_n_bits"],
    a_n_bits=q_config["a_n_bits"],
    act_quant=q_config["act_quant"],
    group_size=q_config["q_group_size"],
)


In [53]:
torch.cuda.empty_cache()
model.cuda()

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): QuantizedLinear()
          (fc2): QuantizedLinear()
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
 

In [54]:
# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=q_config["w_n_bits"], group_size=q_config["q_group_size"])
print(f"\nmodel perplexity: {model_perplexity:.2f}")
print(f"model size: {model_size/MiB:.2f} MiB")

evaluating...: 100%|██████████| 40/40 [00:21<00:00,  1.86it/s]


model perplexity: 23815.02
model size: 248.04 MiB





In [42]:

print("hi")

hi


## RANDOM STUFF

In [43]:
from awq.quantize.pre_quant_lester import get_blocks

In [44]:
for layer in get_blocks(model):
    # layer.fc1
    # layer.fc2
    # print("hi")
    break

In [45]:
from copy import deepcopy

x = deepcopy(layer)

In [46]:
x.fc1.weight.data[:] = 1.0
# x.fc1.weight.data

layer.fc1.weight.data

tensor([[ 3.1622e-04,  2.4802e-05, -1.9221e-04,  ...,  1.8563e-04,
          6.4971e-04,  2.0110e-04],
        [ 5.2574e-05,  2.9792e-04,  3.6802e-04,  ..., -6.5032e-04,
         -3.0992e-04, -1.0161e-04],
        [-2.0619e-04, -1.8504e-04,  1.9033e-04,  ..., -2.7781e-04,
          5.7877e-06,  3.2990e-04],
        ...,
        [ 3.4242e-04,  1.7121e-04, -1.6377e-04,  ...,  2.4482e-04,
         -1.8939e-04,  2.0787e-04],
        [-5.0239e-05,  1.1722e-04, -1.1722e-04,  ...,  1.6503e-04,
          1.0045e-04, -2.8701e-04],
        [-4.2769e-05,  0.0000e+00,  2.5234e-04,  ..., -1.2389e-05,
         -2.8908e-05,  9.9112e-05]], device='cuda:0')

In [47]:
isinstance(layer, OPTDecoderLayer)

True