In [1]:
%env CUDA_VISIBLE_DEVICES=3
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
%env HF_HOME=/mnt/LLM/
%env OMP_NUM_THREADS=16
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(0, '..')

import time
import random
from tqdm.auto import trange, tqdm
import numpy as np
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight, QuantizedLinear
from src.modelutils import get_model
from src.datautils import get_loaders
from convert_legacy_model_format import load_quantized_model_with_old_pickle


torch.set_num_threads(16)
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


env: CUDA_VISIBLE_DEVICES=3
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: HF_HOME=/mnt/LLM/
env: OMP_NUM_THREADS=16




In [2]:
class args:
    base_model = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
    quant_model = "/extra_disk_1/vahe1994/AQ/tinyllama-3t-2x8g8/"
    dtype = 'bfloat16'
    model_seqlen = 1024  # can be 2048 for 1.1B, 4096-8192 for larger models
    device_map = 'auto'
    
    dataset = 'pajama'
    step_nsamples = 256
    total_nsamples = 2560#2560
    seed = 42
    beam_size = 4
    stochastic_rounding_tau = 0.0
    max_code_change_per_step=0.1 #1e-3
    code_trust_ratio=0.1
    entropy_reg=0.1
    code_selection_temperature=100
    
    code_lr = 1e-3
    code_lr_plateau_scale = 0.5
    code_betas = (0.0, 0.95)
    delta_decay = 1.0
    codebook_lr = 1e-5
    codebook_betas = (0.9, 0.95)
    codebook_grad_accumulation_steps = 8
    
    
    autocast_dtype = torch.bfloat16  # bfloat16 or None (not using grad scaler!)
    training_dtype = torch.float32
    gradient_checkpointing = False
    devices = [device]

In [3]:
train_data = get_loaders(
    args.dataset,
    nsamples=args.total_nsamples,
    seed=args.seed,
    model_path=args.base_model,
    seqlen=args.model_seqlen,
)

Loading red_pajama from togethercomputer/RedPajama-Data-1T-Sample


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
                                                                                

Loaded data from pajama; len(data)=2560 sequences




In [4]:
base_model = get_model(args.base_model, None, args.dtype, args.device_map)
if not args.device_map:
    base_model = base_model.to(device)

quantized_model = load_quantized_model_with_old_pickle(
    args.base_model, args.quant_model, dtype=args.dtype, device_map=args.device_map)
if not args.device_map:
    quantized_model = quantized_model.to(device)
quantized_model = quantized_model.to(args.training_dtype)

Loading pretrained model ...
Model loaded sucсessfully ...
Loading pretrained model ...
Model loaded sucсessfully ...
Initializing model with random weights...
Loading quantized model ...
Model loaded sucсessfully ...
found 154 quantized weight matrices




In [5]:
from src.pv_utils import create_dequantized_model
dequantized_model, master_parameters = create_dequantized_model(
    quantized_model, reuse_non_quantized=True, dequantized_dtype=args.autocast_dtype
)
for param in dequantized_model.parameters():
    param.data = param.data.to(args.autocast_dtype)

In [6]:
from src.pv_optimizer import StraightThroughAdamW

optimizer = StraightThroughAdamW(
    named_dequantized_params=dict(dequantized_model.named_parameters()),
    named_quantized_params=master_parameters,
    
    update_codes=dict(lr=args.code_lr, betas=args.code_betas),
#     update_codebooks_and_scales=dict(lr=args.codebook_lr, betas=args.codebook_betas),
#     update_non_quantized_parameters=dict(lr=args.codebook_lr, betas=args.codebook_betas),
    code_trust_ratio=args.code_trust_ratio,
    beam_size=args.beam_size,
    max_code_change_per_step=args.max_code_change_per_step,
    stochastic_rounding_tau=args.stochastic_rounding_tau,
    straight_through_buffer_dtype=torch.float32,
    entropy_reg=args.entropy_reg,
    code_selection_temperature=args.code_selection_temperature
)

In [7]:
if args.gradient_checkpointing:
    quantized_model.gradient_checkpointing_enable()
    quantized_model.enable_input_require_grads()
    for module in quantized_model.modules():
        if isinstance(module, QuantizedLinear):
            module.use_checkpoint = True
    dequantized_model.gradient_checkpointing_enable()
    dequantized_model.enable_input_require_grads()


In [8]:
def _run_one_step(args, base_model, dequantized_model, optimizer, train_data, **kwargs):
    optimizer.zero_grad(set_to_none=True)
    with tqdm(train_data, desc="V step") as progress:

        total_loss = 0.0
        for i, batch in enumerate(progress):
            batch = torch.as_tensor(batch, device=device)
            with torch.no_grad():
                teacher_logits = base_model(batch).logits
            student_logits = dequantized_model(batch).logits  # forward accumulates XTX statistics
            loss = kl_div(student_logits, teacher_logits)
            (loss / len(train_data)).backward()  # backward accumulates gradient
            total_loss = loss.item() / (i + 1) + total_loss * i / (i + 1)
            progress.desc = f"V step: accumulating gradients, loss = {total_loss:.9f}"
            del student_logits, teacher_logits, loss
    optimizer.step(**kwargs)
    optimizer.zero_grad(set_to_none=True)  # reset statistics for the next step
    return total_loss


def kl_div(student_hiddens, teacher_hiddens):
    C = student_hiddens.shape[-1]  # num classes
    return F.kl_div(
        input=F.log_softmax(student_hiddens.view(-1, C), dim=-1),
        target=F.log_softmax(teacher_hiddens.view(-1, C), dim=-1),
        log_target=True,
        reduction="batchmean",
    )


In [9]:
eval_data = get_loaders(
    'wikitext2',
    seed=args.seed,
    model_path=args.base_model,
    seqlen=args.model_seqlen,
    eval_mode=True,
)

@torch.inference_mode()
def eval_ppl_naive(model, eval_data):
    eval_inps = [
        eval_data[:, start: start + args.model_seqlen] for start in range(0, eval_data.shape[1], args.model_seqlen)
    ]
    total_tokens = 0
    nlls = []
    for input_ids in tqdm(eval_inps):
        input_ids = input_ids.to(device)
        lm_logits = model(input_ids).logits

        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * args.model_seqlen
        nlls.append(neg_log_likelihood)
        total_tokens += shift_labels.numel()
    ppl = torch.exp(torch.stack(nlls).sum() / total_tokens)
    return ppl

Loaded data from wikitext2; len(data)=1 sequences


In [10]:
POINTER = 0
def next_train_data():
    global POINTER
    batch = []
    for i in range(args.step_nsamples):
        batch.append(train_data[POINTER % len(train_data)])
        POINTER += 1
    return batch

In [None]:
last_ppl_eval_score = None
for i in range(100):
    print("STEP", i)
    if i % 5 == 0:
        ppl_eval_score = eval_ppl_naive(dequantized_model, eval_data).item()
        last_ppl_eval_score = ppl_eval_score
        print("wikitext2 ppl (naive)", ppl_eval_score)

    v_step_train_loss = _run_one_step(args, base_model, dequantized_model, optimizer, next_train_data())
    print("train loss:", v_step_train_loss)


STEP 0


  0%|          | 0/334 [00:00<?, ?it/s]

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


wikitext2 ppl (naive) 18.941404342651367


V step:   0%|          | 0/256 [00:00<?, ?it/s]

AVG entropy(not correct): 7.960147365347131
train loss: 0.6543046110309659
STEP 1


V step:   0%|          | 0/256 [00:00<?, ?it/s]

AVG entropy(not correct): 7.9397684598897955
train loss: 0.5452335950685666
STEP 2


V step:   0%|          | 0/256 [00:00<?, ?it/s]

AVG entropy(not correct): 7.915862863714045
train loss: 0.6128921939525758
STEP 3


V step:   0%|          | 0/256 [00:00<?, ?it/s]

AVG entropy(not correct): 7.889609107723484
train loss: 0.6746042658924123
STEP 4


V step:   0%|          | 0/256 [00:00<?, ?it/s]

AVG entropy(not correct): 7.860826043339519
train loss: 0.7120309435995306
STEP 5


  0%|          | 0/334 [00:00<?, ?it/s]

wikitext2 ppl (naive) 21.438142776489258


V step:   0%|          | 0/256 [00:00<?, ?it/s]

AVG entropy(not correct): 7.829715837131847
train loss: 0.7723719145869837
STEP 6


V step:   0%|          | 0/256 [00:00<?, ?it/s]

AVG entropy(not correct): 7.797919966957786
train loss: 0.8470790250576108
STEP 7


V step:   0%|          | 0/256 [00:00<?, ?it/s]

AVG entropy(not correct): 7.764575208936419
train loss: 0.9057074269512672
STEP 8


V step:   0%|          | 0/256 [00:00<?, ?it/s]