In [1]:
import torch
# Allow TF32 math for CUDA matmuls (faster, slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True

import os
import nip
from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
from torch.utils.checkpoint import checkpoint
import sys
# Extend the import path with the local Quantization repo; `qlib` is imported
# right after this, presumably from that location -- TODO confirm.
sys.path.append("/home/msst/repo/Quantization")
import qlib

# Candidate optimizer classes; the one actually used is selected later via
# `optimizer_cls`.
from torch.optim.adamw import AdamW
from transformers.optimization import Adafactor
from bitsandbytes.optim.adamw import AdamW as AdamW8bit

# Device the quantized model (and, for KD losses, the teacher lm_head) is moved to.
DEVICE = 'cuda:0'

In [2]:
# Directory holding the quantized checkpoints and the name of the one to fine-tune.
path_to_checkpoints = "/home/msst/repo/Quantization/logs/checkpoints_Llama2-7b-hf/trellis/"
chpnt_name = 'T256_L16_V2_K2_cbs10_LowBitSym_qtip_ptq_bs5'

# Load the checkpoint in fp16 and place it on DEVICE. Only `.to(DEVICE)` is
# used (no bare `.cuda()`): a bare `.cuda()` targets the current default CUDA
# device and would silently override DEVICE whenever the two differ.
qmodel = qlib.QuantizedLlamaForCausalLM.from_pretrained(
    os.path.join(path_to_checkpoints, chpnt_name),
    torch_dtype=torch.float16,
).to(DEVICE)
qmodel.train()
# Turn on activation checkpointing, using torch.utils.checkpoint.checkpoint
# as the checkpointing function.
qmodel._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=checkpoint)


model_name = 'Llama2-7b-hf'
tokenizer = qlib.load_tokenizer(model_name)

LOL, Im custom!


In [3]:
# Optimizer class later handed to the Trainer.
optimizer_cls = AdamW
#optimizer_cls = Adafactor

# One bucket of trainable tensors per learning-rate group, plus the group LRs.
trainable_params_g1, trainable_params_g2 = [], []
trainable_params_g3, trainable_params_g4 = [], []
lr_g1, lr_g2, lr_g3, lr_g4 = 5e-4, 5e-5, 5e-4, 1e-4

# (name fragments, destination bucket) rows. Every row is checked
# independently and in this order for each parameter, so a name matching
# several rows is appended to several buckets -- the same outcome as a
# sequence of separate `if` statements.
_group_rules = (
    (('SU', 'SV'), trainable_params_g1),
    (('lm_head', 'norm'), trainable_params_g2),
    (('scales',), trainable_params_g3),
    # (('act_scale',), trainable_params_g4),  # currently disabled
)

with torch.no_grad():
    for param_name, param in qmodel.named_parameters():
        for fragments, bucket in _group_rules:
            if not any(fragment in param_name for fragment in fragments):
                continue
            # Selected tensors are cast to fp32 and marked trainable.
            param.data = param.data.to(torch.float32)
            bucket.append(param)
            param.requires_grad = True

# Per-group optimizer settings, stored under the 'optimizer_dict' key.
optimizer_kwargs = {
    'optimizer_dict': [
        {'params': group_params, 'lr': group_lr, 'weight_decay': 0.0}
        for group_params, group_lr in (
            (trainable_params_g1, lr_g1),
            (trainable_params_g2, lr_g2),
            (trainable_params_g3, lr_g3),
            # (trainable_params_g4, lr_g4),  # currently disabled
        )
    ]
}

N_FP_MODEL_PARAMS = 6738415616 # Llama2-7B

def print_number_of_params(group):
    """Print the element count of `group` and its share of the fp model size.

    Args:
        group: iterable of objects exposing ``numel()`` (e.g. parameters).

    Prints a single line containing the summed element count in scientific
    notation and that count as a percentage of ``N_FP_MODEL_PARAMS``.
    """
    total_elements = 0
    for tensor in group:
        total_elements += tensor.numel()
    percent_of_fp = 100 * total_elements / N_FP_MODEL_PARAMS
    count_text = f'Trainalble params: {total_elements:.3e}'
    share_text = f'Fraction of fp model params: {percent_of_fp:.3f}%'
    print(count_text, share_text)

# Print one summary line per parameter group, in group order.
param_groups_to_report = [
    trainable_params_g1,
    trainable_params_g2,
    trainable_params_g3,
    trainable_params_g4,
]
for g in param_groups_to_report:
    print_number_of_params(g)

Trainalble params: 2.499e+06 Fraction of fp model params: 0.037%
Trainalble params: 1.313e+08 Fraction of fp model params: 1.949%
Trainalble params: 2.530e+07 Fraction of fp model params: 0.375%
Trainalble params: 0.000e+00 Fraction of fp model params: 0.000%


In [None]:
# Run configuration: optimizer steps, gradient accumulation and loss variant.
n_steps = 125
grad_acc = 20
loss_type = 'CE'

# Fail fast on a misspelled loss type. Without this check an unknown value
# falls into the CE dataset branch below and only fails later, inside the
# training loop.
if loss_type not in ('CE', 'KD', 'KD+CE'):
    raise ValueError(
        f"Unknown loss_type {loss_type!r}; expected 'CE', 'KD' or 'KD+CE'"
    )


if loss_type in ('KD', 'KD+CE'):
    # KD variants train on data loaded from a .pth file.
    kd_data_path = '/mnt/ssd_storage/ml/weights/vc_data/Llama2-7b-hf/kd_data'
    #dataset_name = 'kd_data_redpajama_decoder_output_small'
    dataset_name = 'kd_data_redpajama_decoder_output'
    kd_data = torch.load(
        f'{kd_data_path}/{dataset_name}.pth',
        weights_only=True
    )
    train_dataset = qlib.KnowledgeDistillationDataset(kd_data)
else:
    # Plain CE training uses the dataset described by the YAML config.
    train_dataset = qlib.QATDataset(
        config=nip.load('/home/msst/repo/Quantization/configs/data/redpajama_train_seqlen4096_large.yaml'),
        tokenizer=qlib.load_tokenizer('Llama2-7b-hf'),
        return_dict=True
    )


# Trainer configuration: `n_steps` optimizer steps, micro-batch of 1 with
# `grad_acc` accumulation steps, loss logged every step, no checkpoints saved
# by the Trainer.
training_args = TrainingArguments(
    max_steps=n_steps,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=grad_acc,
    gradient_checkpointing=True,

    logging_strategy="steps",
    logging_steps=1,
    output_dir = './output_dir',
    save_strategy="no",

    # Optional evaluation settings (disabled):
    # label_names=[],
    # per_device_eval_batch_size=1,
    # eval_strategy='steps',
    # eval_steps=10,
    # eval_on_start=True,

    # Keep every dataset column; presumably needed so extra fields such as
    # 'decoder_output' reach compute_loss -- TODO confirm.
    remove_unused_columns=False,
)

if loss_type=='KD' or loss_type=='KD+CE':
    # KD variants need the full-precision model's lm_head (in fp32, on DEVICE)
    # to turn stored decoder outputs into teacher logits.
    fp_model = qlib.load_model('Llama2-7b-hf', torch_dtype=torch.float16)
    lm_head = fp_model.lm_head.to(torch.float32).to(DEVICE)
    # The teacher head only produces targets and is not in any optimizer
    # group. Freeze it so backward does not accumulate a gradient for its
    # weight that nothing ever reads or clears.
    lm_head.requires_grad_(False)

class CustomTrainer(Trainer):
    """Trainer whose loss is selected by the module-level `loss_type`.

    Supported values:
        'CE'    -- the loss returned by the model for labels = input_ids.
        'KD'    -- per-token KL divergence between the model's logits and
                   teacher logits produced by the module-level `lm_head`.
        'KD+CE' -- ce_loss + 10 * kd_loss.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss for one batch.

        Args:
            model: model called as ``model(input_ids, labels=input_ids)``.
            inputs: mapping with 'input_ids' and, for KD variants,
                'decoder_output' (the input fed to the teacher `lm_head`).
            return_outputs: if True, return ``(loss, outputs)``.
            num_items_in_batch: accepted for Trainer compatibility; unused.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when `return_outputs`.

        Raises:
            ValueError: if `loss_type` is not 'CE', 'KD' or 'KD+CE'.
        """
        if loss_type not in ('CE', 'KD', 'KD+CE'):
            # Explicit error instead of an unbound `total_loss` further down.
            raise ValueError(
                f"Unknown loss_type {loss_type!r}; expected 'CE', 'KD' or 'KD+CE'"
            )
        use_ce = loss_type in ('CE', 'KD+CE')
        use_kd = loss_type in ('KD', 'KD+CE')

        outputs = model(inputs['input_ids'], labels=inputs['input_ids'])
        ce_loss = outputs.loss if use_ce else None

        kd_loss = None
        if use_kd:
            qmodel_logits = outputs.logits
            # Teacher logits are targets only: no autograd graph is needed.
            with torch.no_grad():
                fpmodel_logits = lm_head(inputs['decoder_output'].to(torch.float32))

            # Number of token positions = product of all dims except vocab.
            n_tokens = qmodel_logits.numel() // qmodel_logits.shape[-1]

            T = 1 #2
            # Sum the KL over everything and divide once by the token count.
            # ('batchmean' would already divide by the batch size, so dividing
            # by n_tokens on top of it mis-scales the loss for batch size > 1;
            # for batch size 1 both forms give the same value.)
            kd_loss = torch.nn.functional.kl_div(
                torch.log_softmax(qmodel_logits / T, dim=-1),
                torch.softmax(fpmodel_logits / T, dim=-1),
                reduction='sum',
            ) * (T**2) / n_tokens

        if loss_type=='KD':
            total_loss = kd_loss
        elif loss_type=='CE':
            total_loss = ce_loss
        else:  # 'KD+CE'
            print("kd_loss:", kd_loss.item(), 'ce_loss:', ce_loss.item())
            total_loss = ce_loss + 10 * kd_loss

        return (total_loss, outputs) if return_outputs else total_loss

# Sanity check: show the dtypes of a few tensors from the last decoder layer
# and of the lm_head weight before training starts.
last_layer = qmodel.get_decoder().layers[-1]
print('scales:', last_layer.mlp.up_proj.scales.dtype)
print('SU:', last_layer.mlp.up_proj.SU.dtype)
print('lm_head:', qmodel.lm_head.weight.dtype)

trainer = CustomTrainer(
    model=qmodel,
    args=training_args,
    train_dataset=train_dataset,
    #eval_dataset=eval_dataset,
    # Pass a shallow copy so the Trainer cannot mutate the notebook-level
    # `optimizer_kwargs` (e.g. by popping the param groups out of it), which
    # would make a re-run of this cell silently use different param groups.
    optimizer_cls_and_kwargs=(optimizer_cls, dict(optimizer_kwargs))
)

trainer.can_return_loss = True
# NOTE(review): training runs under fp16 autocast and no GradScaler is visible
# in this notebook -- confirm that fp16 gradient underflow is not a concern.
with torch.amp.autocast('cuda', dtype=torch.float16):
    trainer.train()


Resolving data files:   0%|          | 0/200 [00:00<?, ?it/s]

scales: torch.float32
SU: torch.float32
lm_head: torch.float32


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Step,Training Loss


In [5]:
# Output directory name records the loss type, step count, gradient
# accumulation and optimizer class used for this run.
save_dir_name = f'{chpnt_name}_qat_{loss_type}_{n_steps}steps_{grad_acc}ga_{optimizer_cls.__name__}'
path_to_save = os.path.join(path_to_checkpoints, save_dir_name)

# Convert the model to fp16 and write it next to the source checkpoint.
# NOTE(review): for a standard torch.nn.Module, .half() converts in place, so
# `qmodel` itself is fp16 after this cell -- confirm qlib does not override it.
qmodel.half().save_pretrained(path_to_save)