In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [3]:
import awq

In [4]:
import tqdm
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from functools import partial
import gc

In [5]:
def evaluate(model, tokenizer):
    testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    testenc = tokenizer("\n\n".join(testenc["text"]), return_tensors="pt")

    testenc = testenc.input_ids.to(model.device)
    nsamples = 40
    model = model.eval()

    nlls = []
    for i in tqdm.tqdm(range(nsamples), desc="evaluating..."):
        batch = testenc[:, (i * 2048) : ((i + 1) * 2048)].to(model.device)
        with torch.no_grad():
            lm_logits = model(batch).logits
        shift_logits = lm_logits[:, :-1, :].contiguous().float()
        shift_labels = testenc[:, (i * 2048) : ((i + 1) * 2048)][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        neg_log_likelihood = loss.float() * 2048
        nlls.append(neg_log_likelihood)

    return torch.exp(torch.stack(nlls).sum() / (nsamples * 2048))

In [6]:
def get_model_size(model: nn.Module, data_width=16, group_size=-1):

    if group_size != -1:
        data_width += (16 + 4) / group_size

    num_elements = 0
    for param in model.parameters():
        num_elements += param.numel()
    return num_elements * data_width


Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

In [7]:
model_path = "facebook/opt-1.3b"

# model_path = "facebook/opt-13b"

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)



In [8]:
original_model_n_bits = 32
torch_dtype = torch.float16 if original_model_n_bits == 16 else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch_dtype, device_map="auto"
)



In [9]:
# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=original_model_n_bits, group_size=128)

### Print the results
print(f"\nmodel perplexity: {model_perplexity:.2f}")
print(f"model size: {model_size/MiB:.2f} MiB")

evaluating...: 100%|██████████| 40/40 [00:20<00:00,  1.93it/s]



model perplexity: 14.47
model size: 5043.73 MiB


In [16]:
from awq.quantize.pre_quant_lester import run_awq, apply_awq
from typing import Literal

ActQuantType = Literal["per_token", "per_tensor", "none"]

q_config = {
    "zero_point": True,  # by default True
    "q_group_size": 128,  # whether to use group quantization
    "w_n_bits": 8,
    "a_n_bits": 8,
    "act_quant": "per_tensor",
}

model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch_dtype, device_map="auto"
)

awq_results = run_awq(
    model,
    tokenizer,
    w_bit=q_config["w_n_bits"],
    q_config=q_config,
    n_samples=128,
    seqlen=512,
)

Repo card metadata block was not found. Setting CardData to empty.


 * Split into 60 blocks


Running AWQ...:   0%|          | 0/24 [00:00<?, ?it/s]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.00427306117489934
ratio: 0.05, loss: 0.004263454116880894
ratio: 0.1, loss: 0.004256470128893852
ratio: 0.15, loss: 0.004252363927662373
ratio: 0.2, loss: 0.0042498731054365635
ratio: 0.25, loss: 0.004248196259140968
ratio: 0.3, loss: 0.004247210454195738
ratio: 0.35, loss: 0.004246462602168322
ratio: 0.4, loss: 0.004246068187057972
ratio: 0.45, loss: 0.0042457422241568565
ratio: 0.5, loss: 0.004245585296303034
ratio: 0.55, loss: 0.004245635122060776
ratio: 0.6, loss: 0.004245671909302473
ratio: 0.65, loss: 0.00424604257568717
ratio: 0.7, loss: 0.004246539901942015
ratio: 0.75, loss: 0.0042471387423574924
ratio: 0.8, loss: 0.004248109646141529
ratio: 0.85, loss: 0.004249550402164459
ratio: 0.9, loss: 0.004252143204212189
ratio: 0.95, loss: 0.0

Running AWQ...:   4%|▍         | 1/24 [00:06<02:39,  6.95s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.009539443999528885
ratio: 0.05, loss: 0.00952586904168129
ratio: 0.1, loss: 0.009515193291008472
ratio: 0.15, loss: 0.009507773444056511
ratio: 0.2, loss: 0.009505671449005604
ratio: 0.25, loss: 0.009502706117928028
ratio: 0.3, loss: 0.009502350352704525
ratio: 0.35, loss: 0.009501579217612743
ratio: 0.4, loss: 0.009501717984676361
ratio: 0.45, loss: 0.009500639513134956
ratio: 0.5, loss: 0.00950048677623272
ratio: 0.55, loss: 0.009500422514975071
ratio: 0.6, loss: 0.009501753374934196
ratio: 0.65, loss: 0.009501888416707516
ratio: 0.7, loss: 0.009503398090600967
ratio: 0.75, loss: 0.009504575282335281
ratio: 0.8, loss: 0.009507945738732815
ratio: 0.85, loss: 0.009513082914054394
ratio: 0.9, loss: 0.009520062245428562
ratio: 0.95, loss: 0.0095

Running AWQ...:   8%|▊         | 2/24 [00:14<02:42,  7.38s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.014652990736067295
ratio: 0.05, loss: 0.014642191119492054
ratio: 0.1, loss: 0.014632843434810638
ratio: 0.15, loss: 0.014625567942857742
ratio: 0.2, loss: 0.014622591435909271
ratio: 0.25, loss: 0.01461766753345728
ratio: 0.3, loss: 0.014620897360146046
ratio: 0.35, loss: 0.014618327841162682
ratio: 0.4, loss: 0.014618946239352226
ratio: 0.45, loss: 0.014619222842156887
ratio: 0.5, loss: 0.01461872924119234
ratio: 0.55, loss: 0.014617585577070713
ratio: 0.6, loss: 0.014619668014347553
ratio: 0.65, loss: 0.014619214460253716
ratio: 0.7, loss: 0.014620248228311539
ratio: 0.75, loss: 0.0146217355504632
ratio: 0.8, loss: 0.014621896669268608
ratio: 0.85, loss: 0.014626724645495415
ratio: 0.9, loss: 0.014631900005042553
ratio: 0.95, loss: 0.014641

Running AWQ...:  12%|█▎        | 3/24 [00:23<02:47,  7.98s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.01361881010234356
ratio: 0.05, loss: 0.013607820495963097
ratio: 0.1, loss: 0.013600076548755169
ratio: 0.15, loss: 0.013592327944934368
ratio: 0.2, loss: 0.0135918278247118
ratio: 0.25, loss: 0.013588282279670238
ratio: 0.3, loss: 0.013587272725999355
ratio: 0.35, loss: 0.01358592789620161
ratio: 0.4, loss: 0.013586241751909256
ratio: 0.45, loss: 0.013586432673037052
ratio: 0.5, loss: 0.013586614280939102
ratio: 0.55, loss: 0.013586084358394146
ratio: 0.6, loss: 0.013586231507360935
ratio: 0.65, loss: 0.013585977256298065
ratio: 0.7, loss: 0.01358698308467865
ratio: 0.75, loss: 0.013587454333901405
ratio: 0.8, loss: 0.013589572161436081
ratio: 0.85, loss: 0.013590722344815731
ratio: 0.9, loss: 0.013592743314802647
ratio: 0.95, loss: 0.0135989

Running AWQ...:  17%|█▋        | 4/24 [00:30<02:35,  7.77s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.010516151785850525
ratio: 0.05, loss: 0.010507465340197086
ratio: 0.1, loss: 0.010499853640794754
ratio: 0.15, loss: 0.010496027767658234
ratio: 0.2, loss: 0.010492024943232536
ratio: 0.25, loss: 0.010491888038814068
ratio: 0.3, loss: 0.010489150881767273
ratio: 0.35, loss: 0.010489038191735744
ratio: 0.4, loss: 0.01048866007477045
ratio: 0.45, loss: 0.010488361120223999
ratio: 0.5, loss: 0.010488242842257023
ratio: 0.55, loss: 0.01048858743160963
ratio: 0.6, loss: 0.010488533414900303
ratio: 0.65, loss: 0.010488837957382202
ratio: 0.7, loss: 0.010489342734217644
ratio: 0.75, loss: 0.010489837266504765
ratio: 0.8, loss: 0.010491437278687954
ratio: 0.85, loss: 0.010492819361388683
ratio: 0.9, loss: 0.010495488531887531
ratio: 0.95, loss: 0.0104

Running AWQ...:  21%|██        | 5/24 [00:39<02:33,  8.08s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.008515790104866028
ratio: 0.05, loss: 0.00850585661828518
ratio: 0.1, loss: 0.008499841205775738
ratio: 0.15, loss: 0.008494972251355648
ratio: 0.2, loss: 0.008492840453982353
ratio: 0.25, loss: 0.008491840213537216
ratio: 0.3, loss: 0.008491049520671368
ratio: 0.35, loss: 0.008490134961903095
ratio: 0.4, loss: 0.0084895770996809
ratio: 0.45, loss: 0.008489486761391163
ratio: 0.5, loss: 0.008489119820296764
ratio: 0.55, loss: 0.008489210158586502
ratio: 0.6, loss: 0.008489369414746761
ratio: 0.65, loss: 0.008489575237035751
ratio: 0.7, loss: 0.00849023275077343
ratio: 0.75, loss: 0.008490930311381817
ratio: 0.8, loss: 0.00849192962050438
ratio: 0.85, loss: 0.008493592962622643
ratio: 0.9, loss: 0.008495701476931572
ratio: 0.95, loss: 0.0084998

Running AWQ...:  25%|██▌       | 6/24 [00:47<02:28,  8.25s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.009074023924767971
ratio: 0.05, loss: 0.009063711389899254
ratio: 0.1, loss: 0.00905846431851387
ratio: 0.15, loss: 0.00905437022447586
ratio: 0.2, loss: 0.009052379988133907
ratio: 0.25, loss: 0.009050266817212105
ratio: 0.3, loss: 0.009049632586538792
ratio: 0.35, loss: 0.009049005806446075
ratio: 0.4, loss: 0.009048647247254848
ratio: 0.45, loss: 0.009048274718225002
ratio: 0.5, loss: 0.009047935716807842
ratio: 0.55, loss: 0.00904820766299963
ratio: 0.6, loss: 0.00904824398458004
ratio: 0.65, loss: 0.009048684500157833
ratio: 0.7, loss: 0.009048941545188427
ratio: 0.75, loss: 0.009049803949892521
ratio: 0.8, loss: 0.009050886146724224
ratio: 0.85, loss: 0.009052267298102379
ratio: 0.9, loss: 0.009054504334926605
ratio: 0.95, loss: 0.009058

Running AWQ...:  29%|██▉       | 7/24 [00:56<02:19,  8.23s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.009939935058355331
ratio: 0.05, loss: 0.00993150845170021
ratio: 0.1, loss: 0.009926311671733856
ratio: 0.15, loss: 0.009923464618623257
ratio: 0.2, loss: 0.009920240379869938
ratio: 0.25, loss: 0.009918536990880966
ratio: 0.3, loss: 0.009917638264596462
ratio: 0.35, loss: 0.009916557930409908
ratio: 0.4, loss: 0.009916451759636402
ratio: 0.45, loss: 0.009915970265865326
ratio: 0.5, loss: 0.00991566851735115
ratio: 0.55, loss: 0.009915735572576523
ratio: 0.6, loss: 0.009916026145219803
ratio: 0.65, loss: 0.009916340000927448
ratio: 0.7, loss: 0.00991677027195692
ratio: 0.75, loss: 0.009917688556015491
ratio: 0.8, loss: 0.00991878006607294
ratio: 0.85, loss: 0.009920397773385048
ratio: 0.9, loss: 0.009922701865434647
ratio: 0.95, loss: 0.009926

Running AWQ...:  33%|███▎      | 8/24 [01:04<02:13,  8.37s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.010508631356060505
ratio: 0.05, loss: 0.010501382872462273
ratio: 0.1, loss: 0.010495446622371674
ratio: 0.15, loss: 0.010492194443941116
ratio: 0.2, loss: 0.01048994529992342
ratio: 0.25, loss: 0.01048749964684248
ratio: 0.3, loss: 0.010486633516848087
ratio: 0.35, loss: 0.010485555045306683
ratio: 0.4, loss: 0.010485399514436722
ratio: 0.45, loss: 0.010484931990504265
ratio: 0.5, loss: 0.010484803467988968
ratio: 0.55, loss: 0.010484615340828896
ratio: 0.6, loss: 0.010484886355698109
ratio: 0.65, loss: 0.010485146194696426
ratio: 0.7, loss: 0.010485935024917126
ratio: 0.75, loss: 0.01048693060874939
ratio: 0.8, loss: 0.010488131083548069
ratio: 0.85, loss: 0.010490022599697113
ratio: 0.9, loss: 0.010492988862097263
ratio: 0.95, loss: 0.01049

Running AWQ...:  38%|███▊      | 9/24 [01:13<02:05,  8.34s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.010445923544466496
ratio: 0.05, loss: 0.010437730699777603
ratio: 0.1, loss: 0.010432631708681583
ratio: 0.15, loss: 0.0104296263307333
ratio: 0.2, loss: 0.010426882654428482
ratio: 0.25, loss: 0.010424953885376453
ratio: 0.3, loss: 0.010424028150737286
ratio: 0.35, loss: 0.010423323139548302
ratio: 0.4, loss: 0.01042274571955204
ratio: 0.45, loss: 0.010422726161777973
ratio: 0.5, loss: 0.01042228378355503
ratio: 0.55, loss: 0.010422253981232643
ratio: 0.6, loss: 0.010422606952488422
ratio: 0.65, loss: 0.010423162020742893
ratio: 0.7, loss: 0.010423710569739342
ratio: 0.75, loss: 0.010424463078379631
ratio: 0.8, loss: 0.010425860062241554
ratio: 0.85, loss: 0.010427786037325859
ratio: 0.9, loss: 0.010430416092276573
ratio: 0.95, loss: 0.010434

Running AWQ...:  42%|████▏     | 10/24 [01:22<01:59,  8.52s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.010541156865656376
ratio: 0.05, loss: 0.010535016655921936
ratio: 0.1, loss: 0.010530130937695503
ratio: 0.15, loss: 0.010527344420552254
ratio: 0.2, loss: 0.010524490848183632
ratio: 0.25, loss: 0.010523764416575432
ratio: 0.3, loss: 0.010522604919970036
ratio: 0.35, loss: 0.010521795600652695
ratio: 0.4, loss: 0.010521509684622288
ratio: 0.45, loss: 0.010521374642848969
ratio: 0.5, loss: 0.010521285235881805
ratio: 0.55, loss: 0.010521255433559418
ratio: 0.6, loss: 0.010521446354687214
ratio: 0.65, loss: 0.0105219054967165
ratio: 0.7, loss: 0.010522552765905857
ratio: 0.75, loss: 0.010523269884288311
ratio: 0.8, loss: 0.010524557903409004
ratio: 0.85, loss: 0.010526481084525585
ratio: 0.9, loss: 0.010528523474931717
ratio: 0.95, loss: 0.0105

Running AWQ...:  46%|████▌     | 11/24 [01:30<01:49,  8.41s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.010311237536370754
ratio: 0.05, loss: 0.01030502188950777
ratio: 0.1, loss: 0.0103019243106246
ratio: 0.15, loss: 0.010299373418092728
ratio: 0.2, loss: 0.010297825559973717
ratio: 0.25, loss: 0.010296612977981567
ratio: 0.3, loss: 0.010295433923602104
ratio: 0.35, loss: 0.010294800624251366
ratio: 0.4, loss: 0.010294465348124504
ratio: 0.45, loss: 0.010294131003320217
ratio: 0.5, loss: 0.01029418408870697
ratio: 0.55, loss: 0.010293995030224323
ratio: 0.6, loss: 0.010294252075254917
ratio: 0.65, loss: 0.01029467023909092
ratio: 0.7, loss: 0.010295186191797256
ratio: 0.75, loss: 0.010296054184436798
ratio: 0.8, loss: 0.010297233238816261
ratio: 0.85, loss: 0.010299010202288628
ratio: 0.9, loss: 0.010301008820533752
ratio: 0.95, loss: 0.0103039

Running AWQ...:  50%|█████     | 12/24 [01:38<01:42,  8.52s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.009674214757978916
ratio: 0.05, loss: 0.009670856408774853
ratio: 0.1, loss: 0.009667651727795601
ratio: 0.15, loss: 0.009666044265031815
ratio: 0.2, loss: 0.009664404205977917
ratio: 0.25, loss: 0.009663206525146961
ratio: 0.3, loss: 0.00966224167495966
ratio: 0.35, loss: 0.009662067517638206
ratio: 0.4, loss: 0.00966189056634903
ratio: 0.45, loss: 0.009661446325480938
ratio: 0.5, loss: 0.009661777876317501
ratio: 0.55, loss: 0.00966172106564045
ratio: 0.6, loss: 0.009661879390478134
ratio: 0.65, loss: 0.009662142023444176
ratio: 0.7, loss: 0.009662688709795475
ratio: 0.75, loss: 0.009663743898272514
ratio: 0.8, loss: 0.00966496393084526
ratio: 0.85, loss: 0.009666632860898972
ratio: 0.9, loss: 0.00966867245733738
ratio: 0.95, loss: 0.0096721

Running AWQ...:  54%|█████▍    | 13/24 [01:47<01:33,  8.52s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.009752225130796432
ratio: 0.05, loss: 0.009748426266014576
ratio: 0.1, loss: 0.009746263734996319
ratio: 0.15, loss: 0.009744052775204182
ratio: 0.2, loss: 0.009742810390889645
ratio: 0.25, loss: 0.009741677902638912
ratio: 0.3, loss: 0.009741328656673431
ratio: 0.35, loss: 0.009740639477968216
ratio: 0.4, loss: 0.009740132838487625
ratio: 0.45, loss: 0.009740123525261879
ratio: 0.5, loss: 0.009740306064486504
ratio: 0.55, loss: 0.009740377776324749
ratio: 0.6, loss: 0.009740553796291351
ratio: 0.65, loss: 0.009741109795868397
ratio: 0.7, loss: 0.009741701185703278
ratio: 0.75, loss: 0.009742778725922108
ratio: 0.8, loss: 0.009743751958012581
ratio: 0.85, loss: 0.009745660237967968
ratio: 0.9, loss: 0.009748024865984917
ratio: 0.95, loss: 0.00

Running AWQ...:  58%|█████▊    | 14/24 [01:55<01:24,  8.46s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.00994689017534256
ratio: 0.05, loss: 0.009944210760295391
ratio: 0.1, loss: 0.009941477328538895
ratio: 0.15, loss: 0.009939990006387234
ratio: 0.2, loss: 0.009939001873135567
ratio: 0.25, loss: 0.009938289411365986
ratio: 0.3, loss: 0.009938099421560764
ratio: 0.35, loss: 0.009937197901308537
ratio: 0.4, loss: 0.009936759248375893
ratio: 0.45, loss: 0.009936757385730743
ratio: 0.5, loss: 0.009936771355569363
ratio: 0.55, loss: 0.009937000460922718
ratio: 0.6, loss: 0.009937342256307602
ratio: 0.65, loss: 0.009937628172338009
ratio: 0.7, loss: 0.0099385567009449
ratio: 0.75, loss: 0.00993956346064806
ratio: 0.8, loss: 0.00994087290018797
ratio: 0.85, loss: 0.009942512959241867
ratio: 0.9, loss: 0.009944835677742958
ratio: 0.95, loss: 0.0099481

Running AWQ...:  62%|██████▎   | 15/24 [02:04<01:16,  8.48s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.009998282417654991
ratio: 0.05, loss: 0.009996134787797928
ratio: 0.1, loss: 0.009993434883654118
ratio: 0.15, loss: 0.009992463514208794
ratio: 0.2, loss: 0.00999155268073082
ratio: 0.25, loss: 0.00999070331454277
ratio: 0.3, loss: 0.009990029036998749
ratio: 0.35, loss: 0.009989982470870018
ratio: 0.4, loss: 0.00998988188803196
ratio: 0.45, loss: 0.00998962763696909
ratio: 0.5, loss: 0.00998956710100174
ratio: 0.55, loss: 0.009989466518163681
ratio: 0.6, loss: 0.009989796206355095
ratio: 0.65, loss: 0.009990474209189415
ratio: 0.7, loss: 0.00999077782034874
ratio: 0.75, loss: 0.009991729632019997
ratio: 0.8, loss: 0.009992842562496662
ratio: 0.85, loss: 0.0099942646920681
ratio: 0.9, loss: 0.009996245615184307
ratio: 0.95, loss: 0.0099985357

Running AWQ...:  67%|██████▋   | 16/24 [02:12<01:06,  8.35s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.010620676912367344
ratio: 0.05, loss: 0.010621014051139355
ratio: 0.1, loss: 0.010618721134960651
ratio: 0.15, loss: 0.010617048479616642
ratio: 0.2, loss: 0.010615921579301357
ratio: 0.25, loss: 0.010615384206175804
ratio: 0.3, loss: 0.010614943690598011
ratio: 0.35, loss: 0.010614936240017414
ratio: 0.4, loss: 0.010614512488245964
ratio: 0.45, loss: 0.010614358820021152
ratio: 0.5, loss: 0.010614380240440369
ratio: 0.55, loss: 0.010614723898470402
ratio: 0.6, loss: 0.010615031234920025
ratio: 0.65, loss: 0.010615143924951553
ratio: 0.7, loss: 0.010615858249366283
ratio: 0.75, loss: 0.010616443119943142
ratio: 0.8, loss: 0.010617315769195557
ratio: 0.85, loss: 0.010618542321026325
ratio: 0.9, loss: 0.010620047338306904
ratio: 0.95, loss: 0.01

Running AWQ...:  71%|███████   | 17/24 [02:21<00:59,  8.53s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.011285732500255108
ratio: 0.05, loss: 0.011285433545708656
ratio: 0.1, loss: 0.01128321886062622
ratio: 0.15, loss: 0.01128286961466074
ratio: 0.2, loss: 0.011282328516244888
ratio: 0.25, loss: 0.011281279847025871
ratio: 0.3, loss: 0.011280934326350689
ratio: 0.35, loss: 0.011281007900834084
ratio: 0.4, loss: 0.01128073874861002
ratio: 0.45, loss: 0.011280602775514126
ratio: 0.5, loss: 0.011280574835836887
ratio: 0.55, loss: 0.011280765756964684
ratio: 0.6, loss: 0.011280903592705727
ratio: 0.65, loss: 0.011281179264187813
ratio: 0.7, loss: 0.011281580664217472
ratio: 0.75, loss: 0.011281857267022133
ratio: 0.8, loss: 0.011282515712082386
ratio: 0.85, loss: 0.01128337997943163
ratio: 0.9, loss: 0.011284596286714077
ratio: 0.95, loss: 0.011285

Running AWQ...:  75%|███████▌  | 18/24 [02:29<00:51,  8.50s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.011849013157188892
ratio: 0.05, loss: 0.011849025264382362
ratio: 0.1, loss: 0.011847290210425854
ratio: 0.15, loss: 0.011847369372844696
ratio: 0.2, loss: 0.011845954693853855
ratio: 0.25, loss: 0.011845551431179047
ratio: 0.3, loss: 0.011845500208437443
ratio: 0.35, loss: 0.011844991706311703
ratio: 0.4, loss: 0.011845432221889496
ratio: 0.45, loss: 0.011844983324408531
ratio: 0.5, loss: 0.011844474822282791
ratio: 0.55, loss: 0.011844661086797714
ratio: 0.6, loss: 0.01184503361582756
ratio: 0.65, loss: 0.011845188215374947
ratio: 0.7, loss: 0.011845563538372517
ratio: 0.75, loss: 0.011845903471112251
ratio: 0.8, loss: 0.011846469715237617
ratio: 0.85, loss: 0.011846876703202724
ratio: 0.9, loss: 0.01184744480997324
ratio: 0.95, loss: 0.0118

Running AWQ...:  79%|███████▉  | 19/24 [02:38<00:42,  8.46s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.012724155560135841
ratio: 0.05, loss: 0.012722507119178772
ratio: 0.1, loss: 0.012722522020339966
ratio: 0.15, loss: 0.012721046805381775
ratio: 0.2, loss: 0.012720970436930656
ratio: 0.25, loss: 0.012720184400677681
ratio: 0.3, loss: 0.01272035762667656
ratio: 0.35, loss: 0.012719481252133846
ratio: 0.4, loss: 0.012719560414552689
ratio: 0.45, loss: 0.012719854712486267
ratio: 0.5, loss: 0.01271981280297041
ratio: 0.55, loss: 0.01271980069577694
ratio: 0.6, loss: 0.012719874270260334
ratio: 0.65, loss: 0.012720128521323204
ratio: 0.7, loss: 0.012720244005322456
ratio: 0.75, loss: 0.012720111757516861
ratio: 0.8, loss: 0.012720695696771145
ratio: 0.85, loss: 0.012721067294478416
ratio: 0.9, loss: 0.012721851468086243
ratio: 0.95, loss: 0.01272

Running AWQ...:  83%|████████▎ | 20/24 [02:46<00:33,  8.35s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.013138605281710625
ratio: 0.05, loss: 0.013138435781002045
ratio: 0.1, loss: 0.013137154281139374
ratio: 0.15, loss: 0.013136611320078373
ratio: 0.2, loss: 0.013137019239366055
ratio: 0.25, loss: 0.01313579361885786
ratio: 0.3, loss: 0.013136031106114388
ratio: 0.35, loss: 0.013135704211890697
ratio: 0.4, loss: 0.01313498243689537
ratio: 0.45, loss: 0.013135014101862907
ratio: 0.5, loss: 0.013135211542248726
ratio: 0.55, loss: 0.013135004788637161
ratio: 0.6, loss: 0.013135101646184921
ratio: 0.65, loss: 0.013135249726474285
ratio: 0.7, loss: 0.013135562650859356
ratio: 0.75, loss: 0.013135681860148907
ratio: 0.8, loss: 0.013135869987308979
ratio: 0.85, loss: 0.013136335648596287
ratio: 0.9, loss: 0.013136751018464565
ratio: 0.95, loss: 0.0131

Running AWQ...:  88%|████████▊ | 21/24 [02:54<00:25,  8.45s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.013328366912901402
ratio: 0.05, loss: 0.013327497988939285
ratio: 0.1, loss: 0.013326018117368221
ratio: 0.15, loss: 0.013326751999557018
ratio: 0.2, loss: 0.013326087035238743
ratio: 0.25, loss: 0.01332535594701767
ratio: 0.3, loss: 0.013325346633791924
ratio: 0.35, loss: 0.013325224630534649
ratio: 0.4, loss: 0.013325215317308903
ratio: 0.45, loss: 0.013324769213795662
ratio: 0.5, loss: 0.01332525722682476
ratio: 0.55, loss: 0.013325068168342113
ratio: 0.6, loss: 0.013325002044439316
ratio: 0.65, loss: 0.013325396925210953
ratio: 0.7, loss: 0.013325324282050133
ratio: 0.75, loss: 0.013325356878340244
ratio: 0.8, loss: 0.013325857929885387
ratio: 0.85, loss: 0.013326325453817844
ratio: 0.9, loss: 0.013326372019946575
ratio: 0.95, loss: 0.0133

Running AWQ...:  92%|█████████▏| 22/24 [03:03<00:16,  8.47s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.013264557346701622
ratio: 0.05, loss: 0.01326411310583353
ratio: 0.1, loss: 0.013263403438031673
ratio: 0.15, loss: 0.013262837193906307
ratio: 0.2, loss: 0.013263304717838764
ratio: 0.25, loss: 0.013262695632874966
ratio: 0.3, loss: 0.013261978514492512
ratio: 0.35, loss: 0.013262277469038963
ratio: 0.4, loss: 0.013262133114039898
ratio: 0.45, loss: 0.013261237181723118
ratio: 0.5, loss: 0.013261817395687103
ratio: 0.55, loss: 0.013261503539979458
ratio: 0.6, loss: 0.013261623680591583
ratio: 0.65, loss: 0.013261726126074791
ratio: 0.7, loss: 0.013262137770652771
ratio: 0.75, loss: 0.013262114487588406
ratio: 0.8, loss: 0.013262744061648846
ratio: 0.85, loss: 0.013262959197163582
ratio: 0.9, loss: 0.0132635198533535
ratio: 0.95, loss: 0.01326

Running AWQ...:  96%|█████████▌| 23/24 [03:12<00:08,  8.57s/it]

dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj', 'fc1', 'fc2'])
----------
ratio: 0.0, loss: 0.0071968501433730125
ratio: 0.05, loss: 0.0071961767971515656
ratio: 0.1, loss: 0.007195735350251198
ratio: 0.15, loss: 0.007195128593593836
ratio: 0.2, loss: 0.007194618694484234
ratio: 0.25, loss: 0.007194302044808865
ratio: 0.3, loss: 0.007194048725068569
ratio: 0.35, loss: 0.007193746976554394
ratio: 0.4, loss: 0.007193713914602995
ratio: 0.45, loss: 0.007193513214588165
ratio: 0.5, loss: 0.007193679455667734
ratio: 0.55, loss: 0.007193423807621002
ratio: 0.6, loss: 0.007193758152425289
ratio: 0.65, loss: 0.007194096688181162
ratio: 0.7, loss: 0.00719437887892127
ratio: 0.75, loss: 0.007195138838142157
ratio: 0.8, loss: 0.00719570554792881
ratio: 0.85, loss: 0.007196441292762756
ratio: 0.9, loss: 0.007197371684014797
ratio: 0.95, loss: 0.00

Running AWQ...: 100%|██████████| 24/24 [03:20<00:00,  8.37s/it]


In [17]:
dump_awq = "awq_results.pt"
torch.save(awq_results, dump_awq)
print("AWQ results saved at", dump_awq)
#

AWQ results saved at awq_results.pt


In [None]:
### load awq

load_awq = "awq_results.pt"
awq_results = torch.load(load_awq, map_location="cpu")

In [13]:
q_config

{'zero_point': True,
 'q_group_size': 128,
 'w_n_bits': 8,
 'a_n_bits': 8,
 'act_quant': 'per_tensor'}

In [18]:

from transformers.models.opt.modeling_opt import OPTDecoderLayer
from awq.quantize.fake_quant_lester import quantize_opt_model

model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch_dtype, device_map="auto"
)

# apply the AWQ results
apply_awq(model, awq_results)

model = quantize_opt_model(
    model,
    w_n_bits=q_config["w_n_bits"],
    a_n_bits=q_config["a_n_bits"],
    act_quant=q_config["act_quant"],
)


In [19]:
torch.cuda.empty_cache()
model.cuda()

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): QuantizedLinear()
          (fc2): QuantizedLinear()
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
 

In [20]:
# Evaluate the model
model_perplexity = evaluate(model, tokenizer)
model_size = get_model_size(model, data_width=q_config["w_n_bits"], group_size=128)
print(f"\nmodel perplexity: {model_perplexity:.2f}")
print(f"model size: {model_size/MiB:.2f} MiB")

evaluating...: 100%|██████████| 40/40 [00:22<00:00,  1.81it/s]


model perplexity: 13702.26
model size: 496.07 MiB





In [14]:
print("hi")

hi


## RANDOM STUFF

In [None]:
from awq.quantize.pre_quant_lester import get_blocks

In [None]:
for layer in get_blocks(model):
    # layer.fc1
    # layer.fc2
    # print("hi")
    break

In [None]:
from copy import deepcopy

x = deepcopy(layer)

In [None]:
x.fc1.weight.data[:] = 1.0
# x.fc1.weight.data

layer.fc1.weight.data

In [None]:
isinstance(layer, OPTDecoderLayer)