In [1]:
import torch
torch.backends.cuda.matmul.allow_tf32 = True

import nip
from transformers import Trainer, TrainingArguments
import sys
sys.path.append("/home/msst/repo/Quantization")
import qlib


from bitsandbytes.optim.adamw import AdamW as AdamW8bit
from torch.optim.adamw import AdamW
from transformers.optimization import Adafactor


DEVICE = 'cuda:0'

In [2]:
path_to_model = '/home/msst/repo/Quantization/logs/checkpoints_Llama2-7b-hf/SymQuant/cb256_vecdim8_weightPERCOORD_scaleNone_distMSE_blocksizeNone_iters15_abscoords_haar2_ptq.pth'
qmodel = torch.load(path_to_model)
fp_model = qlib.load_model('Llama2-7b-hf', torch_dtype=torch.float16)
tokenizer = qlib.load_tokenizer('Llama2-7b-hf')

  qmodel = torch.load(path_to_model)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
quant_classes=[qlib.SymHQLinear,]

# use_names = ['model.layers.0.', 
#              'model.layers.1.',
#              'model.layers.28.', 
#              'model.layers.29.', 
#              'model.layers.30.', 
#              'model.layers.31.']

#use_names = ['model.layers.0.self_attn.q_proj']
#use_names = ['self_attn.o_proj', 'self_attn.q_proj',]
#use_names = ['mlp.down_proj',]
use_names = []

def check_layer(name, use_names):
    for use_name in use_names:
        if use_name in name:
            return True
    return False

for qmodule_name, qmodule in qmodel.named_modules():
    if (qmodule.__class__ in quant_classes) and check_layer(qmodule_name, use_names):
        print(qmodule_name)
        fp_weight = fp_model.get_submodule(qmodule_name).weight.detach()
        qmodule.fp_weight = fp_weight.to(DEVICE)
        qmodule.reassine_params = {
            'batch_size' : 2**14,
            'reassign_step' : None, #8, #2,
        }
        qmodule.trainable = 'qat'
        qmodule.metadata = {
            'new_indices_ratio': [],
        }
        del qmodule.signs

In [4]:
from torch.utils.data import Dataset, DataLoader

class KnowledgeDistillationDataset(Dataset):
    def __init__(self, data):
        self.data = data
        #self.lm_head = lm_head.cpu() #.to(DEVICE)
        self.len = len(self.data['decoder_output'])

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        assert index < self.len
        input_ids = self.data['input'][index]
        decoder_output = self.data['decoder_output'][index]
        return {
            'input_ids' : input_ids[0],
            'decoder_output' : decoder_output[0]
        }

kd_data_path = '/mnt/ssd_storage/ml/weights/vc_data/Llama2-7b-hf/kd_data'
dataset_name = 'kd_data_redpajama_decoder_output_small'
kd_data = torch.load(
    f'{kd_data_path}/{dataset_name}.pth', 
    weights_only=True
)

kd_dataset = KnowledgeDistillationDataset(kd_data)

In [5]:
cb_params = []

with torch.no_grad():
    qmodel.half()
    for param_name, param in qmodel.named_parameters():
        if 'codebook' in param_name:
            cb_params.append(param)
            param.requires_grad = True


optimizer_cls = Adafactor
optimizer_kwargs = {
    'optimizer_dict': [
        {
            'params': cb_params,
            'lr': 1e-4, 
            'weight_decay': 0.0
        },
    ]
}

In [6]:
# eval_dataset = qlib.QATDataset(
#     config=nip.load('/home/msst/repo/Quantization/configs/data/wikitext_test_seqlen4096.yaml'),
#     tokenizer=qlib.load_tokenizer('Llama2-7b-hf'),
# 	return_dict=True
# )

# train_dataset = qlib.QATDataset(
#     config=nip.load('/home/msst/repo/Quantization/configs/data/redpajama_train_seqlen4096.yaml'),
# 	#config=nip.load('/home/msst/repo/Quantization/configs/data/wikitext_test_seqlen4096.yaml'),
#     tokenizer=qlib.load_tokenizer('Llama2-7b-hf'),
#     return_dict=True
# )

#len(eval_dataset), len(train_dataset)

In [7]:
from torch.utils.checkpoint import checkpoint
qmodel.cuda()
qmodel.train()
qmodel._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=checkpoint)

In [8]:
training_args = TrainingArguments(
    max_steps=32,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    
    logging_strategy="steps",
    logging_steps=1,
    output_dir = './output_dir',
    save_strategy="no",
    
    # label_names=[],
    # per_device_eval_batch_size=1,
    # eval_strategy='steps',
    # eval_steps=10,
    # eval_on_start=True,

    remove_unused_columns=False,
)


lm_head = fp_model.lm_head.to(DEVICE)

class CustomTrainer(Trainer):
    # def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    #     outputs = model(inputs['input_ids'], labels=inputs['input_ids'])
    #     loss = outputs.loss
    #     return (loss, outputs) if return_outputs else loss
    
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(inputs['input_ids'], labels=inputs['input_ids'])
        ce_loss = outputs.loss

        qmodel_logits = outputs.logits
        fpmodel_logits = lm_head(inputs['decoder_output'])
        
        n_tokens = torch.prod(torch.tensor(qmodel_logits.shape[:-1]))
        
        T = 2
        kd_loss = torch.nn.functional.kl_div(
                torch.log_softmax(qmodel_logits / T, dim=1),
                torch.softmax(fpmodel_logits / T, dim=1),
                reduction='batchmean',
            ) * (T**2) / n_tokens
        
        # print("kd_loss:", kd_loss.item(), 'ce_loss:', ce_loss.item())
        # sim = torch.argmax(qmodel_logits, dim=-1) == torch.argmax(fpmodel_logits, dim=-1)
        # print('right_labels ratio:', sim.sum().item() / n_tokens)

        total_loss = ce_loss + 0.1 * kd_loss

        return (total_loss, outputs) if return_outputs else total_loss


trainer = CustomTrainer(
    model=qmodel,
    args=training_args,
    #train_dataset=train_dataset,
    train_dataset=kd_dataset,
    #eval_dataset=eval_dataset,
    optimizer_cls_and_kwargs=(optimizer_cls, optimizer_kwargs)
)

trainer.can_return_loss = True
with torch.amp.autocast('cuda', dtype=torch.float16):
    trainer.train()

Step,Training Loss
1,19.8336
2,57.8154


KeyboardInterrupt: 

In [9]:
from qlib.utils.pack_effective import pack_bool_tensor
quant_classes=[qlib.SymHQLinear,]

for qmodule_name, qmodule in qmodel.named_modules():
    if (qmodule.__class__ in quant_classes) and hasattr(qmodule, 'reassine_params'):
        del qmodule.reassine_params
    if (qmodule.__class__ in quant_classes) and hasattr(qmodule, 'fp_weight'):
        packed_signs, _ = pack_bool_tensor((1+torch.sign(qmodule.fp_weight)).bool())
        qmodule.register_buffer('signs', packed_signs)
        del qmodule.fp_weight          
    if (qmodule.__class__ in quant_classes) and hasattr(qmodule, 'new_indices_ratio'):
        del qmodule.new_indices_ratio      
        
    qmodule.trainable = False

torch.save(qmodel, path_to_model[:-4] + '_qat_kd_loss.pth')

In [10]:
# # enable gradient_checkpointing

# from torch.utils.checkpoint import checkpoint
# import functools

# gradient_checkpointing_func = functools.partial(checkpoint)

# for module in self.modules():
#     if hasattr(module, "gradient_checkpointing"):
#         module._gradient_checkpointing_func = gradient_checkpointing_func
#         module.gradient_checkpointing = enable
#         is_gradient_checkpointing_set = True
