In [1]:
import torch
# Allow TF32 for CUDA matrix multiplications (global PyTorch backend flag).
torch.backends.cuda.matmul.allow_tf32 = True

import nip
from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
import sys
# The search path must be extended BEFORE `import qlib` below.
# NOTE(review): presumably qlib lives in this repository checkout — confirm.
sys.path.append("/home/msst/repo/Quantization")
import qlib


# Candidate optimizers; of these, only Adafactor is referenced in the cells shown here.
from bitsandbytes.optim.adamw import AdamW as AdamW8bit
from torch.optim.adamw import AdamW
from transformers.optimization import Adafactor


# Device string used when moving the student model and the teacher head.
DEVICE = 'cuda:0'

In [2]:
# Restore the quantized checkpoint (requesting torch.float16), then move it to DEVICE.
qmodel = qlib.QuantizedLlamaForCausalLM.from_pretrained(
    '/home/msst/repo/Quantization/logs/checkpoints_Llama2-7b-hf/trellis/T256_L16_V2_K2_lbits10_LowBitSym_qtip',
    torch_dtype=torch.float16,
)
qmodel = qmodel.to(DEVICE)

# The tokenizer is looked up by model name.
model_name = 'Llama2-7b-hf'
tokenizer = qlib.load_tokenizer(model_name)

LOL, Im custom!


Some weights of the model checkpoint at /home/msst/repo/Quantization/logs/checkpoints_Llama2-7b-hf/trellis/T256_L16_V2_K2_lbits10_LowBitSym_qtip were not used when initializing QuantizedLlamaForCausalLM: {'model.layers.4.self_attn.v_proj.weight_quantizer.state_candidates', 'model.layers.11.mlp.gate_proj.weight_quantizer.state_candidates', 'model.layers.12.self_attn.q_proj.weight_quantizer.state_candidates', 'model.layers.23.self_attn.v_proj.weight_quantizer.sumdelta', 'model.layers.10.mlp.up_proj.weight_quantizer.state_candidates', 'model.layers.26.self_attn.v_proj.weight_quantizer.sumdelta', 'model.layers.16.self_attn.q_proj.weight_quantizer.sumdelta', 'model.layers.4.self_attn.o_proj.weight_quantizer.sumdelta', 'model.layers.9.self_attn.v_proj.weight_quantizer.state_candidates', 'model.layers.6.mlp.up_proj.weight_quantizer.sumdelta', 'model.layers.11.self_attn.o_proj.weight_quantizer.state_candidates', 'model.layers.16.self_attn.o_proj.weight_quantizer.sumdelta', 'model.layers.6.mlp.

In [3]:
class KnowledgeDistillationDataset(Dataset):
    """Map-style dataset over pre-computed distillation records.

    ``data`` is indexed with the keys ``'input'`` and ``'decoder_output'``.
    The dataset length is taken from the ``'decoder_output'`` entry.
    """

    def __init__(self, data):
        """Keep a reference to ``data`` and cache the number of records."""
        self.data = data
        self.len = len(data['decoder_output'])

    def __len__(self):
        """Return the cached number of records."""
        return self.len

    def __getitem__(self, index):
        """Return element 0 of the ``index``-th input and decoder-output entries.

        The result is a dict with keys ``'input_ids'`` and ``'decoder_output'``.
        An ``index`` that is not below the dataset length trips the assertion.
        """
        assert index < self.len
        sample_input = self.data['input'][index]
        sample_output = self.data['decoder_output'][index]
        return dict(input_ids=sample_input[0], decoder_output=sample_output[0])

# Directory and base name of the serialized distillation data.
kd_data_path = '/mnt/ssd_storage/ml/weights/vc_data/Llama2-7b-hf/kd_data'
# Alternative dataset name kept for reference: 'kd_data_redpajama_decoder_output_small'
dataset_name = 'kd_data_redpajama_decoder_output'

# weights_only=True restricts unpickling to tensors and plain containers.
kd_file = f'{kd_data_path}/{dataset_name}.pth'
kd_data = torch.load(kd_file, weights_only=True)

kd_dataset = KnowledgeDistillationDataset(data=kd_data)

In [4]:
# Select the parameters to fine-tune: every parameter whose name contains
# 'SU', 'SV' or 'norm'. The requires_grad flag of all other parameters is
# left untouched.
# (Debug aid: iterate qmodel.named_parameters() and print the names.)
trainable_params = []

with torch.no_grad():
    qmodel.half()
    for param_name, param in qmodel.named_parameters():
        if any(tag in param_name for tag in ('SU', 'SV', 'norm')):
            param.requires_grad = True
            trainable_params.append(param)


# Adafactor defaults to relative_step=True and scale_parameter=True. In that
# mode it computes its own step size and ignores the 'lr' stored in the param
# group, so the 1e-4 below would silently have no effect. Disabling both makes
# the optimizer honour the explicit learning rate (and the Trainer's LR
# scheduler that drives it).
optimizer_cls = Adafactor
optimizer_kwargs = {
    # The Trainer pops 'optimizer_dict' and passes it as the optimizer's
    # parameter groups; the remaining keys are forwarded as keyword arguments.
    'optimizer_dict': [
        {
            'params': trainable_params,
            'lr': 1e-4,
            'weight_decay': 0.0
        },
    ],
    'relative_step': False,
    'scale_parameter': False,
}

In [5]:
from torch.utils.checkpoint import checkpoint

# Move the student with the shared DEVICE constant instead of `.cuda()`, which
# targets the implicit current CUDA device. This keeps the student on the same
# device as everything else that is placed with DEVICE if that constant changes.
qmodel.to(DEVICE)
qmodel.train()
# Route the model's gradient checkpointing through torch.utils.checkpoint.
qmodel._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=checkpoint)

In [None]:
n_steps = 500

training_args = TrainingArguments(
    # Where the Trainer writes its files; periodic checkpoint saving is off.
    output_dir='./output_dir',
    save_strategy="no",

    # Schedule: n_steps optimizer steps, each accumulating 4 micro-batches of size 1.
    max_steps=n_steps,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,

    # Log at every step.
    logging_strategy="steps",
    logging_steps=1,

    # Evaluation is not configured (previously tried options: label_names=[],
    # per_device_eval_batch_size=1, eval_strategy='steps', eval_steps=10,
    # eval_on_start=True).

    # compute_loss reads inputs['decoder_output'], so extra batch keys must be kept.
    remove_unused_columns=False,
)

# Reference model loaded through qlib; only its output head is used for
# distillation (compute_loss applies `lm_head` to the stored decoder outputs).
fp_model = qlib.load_model('Llama2-7b-hf', torch_dtype=torch.float16)
lm_head = fp_model.lm_head.to(DEVICE)
# The teacher head is never optimized. Freeze it so autograd does not build a
# graph through it or accumulate an unused gradient for its weight matrix on
# every backward pass (kl_div is differentiable w.r.t. its target as well).
lm_head.requires_grad_(False)


# Report how many parameters are being tuned relative to the reference model.
n_trainable_params = sum(p.numel() for p in trainable_params)
fp_model_params = sum(p.numel() for p in fp_model.parameters())
fraq_of_fp_model_params = 100 * n_trainable_params / fp_model_params
print(
    'Trainalble params:', n_trainable_params,
    f'Fraction of fp model: {fraq_of_fp_model_params}%',
)


class CustomTrainer(Trainer):
    """Trainer that optimizes a knowledge-distillation (KL divergence) loss.

    The student logits come from ``model``; the teacher logits are obtained by
    applying the module-level ``lm_head`` to ``inputs['decoder_output']``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the per-token KL divergence between teacher and student.

        Args:
            model: the student model being trained.
            inputs: batch dict with ``'input_ids'`` and ``'decoder_output'``.
            return_outputs: when true, return ``(loss, outputs)``.
            num_items_in_batch: accepted for Trainer compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, model_outputs)`` if ``return_outputs``.
        """
        outputs = model(inputs['input_ids'], labels=inputs['input_ids'])
        # Cross-entropy is only printed for monitoring; it is not part of the loss.
        ce_loss = outputs.loss
        print("ce_loss:", ce_loss)

        qmodel_logits = outputs.logits
        # Teacher logits are fixed targets: compute them without autograd so no
        # graph or gradient is created for the teacher head.
        with torch.no_grad():
            fpmodel_logits = lm_head(inputs['decoder_output'])

        # Number of token positions = product of every dimension except the last.
        n_tokens = qmodel_logits.numel() // qmodel_logits.shape[-1]

        # Softmax temperature; T**2 keeps the gradient scale comparable across T.
        T = 1
        # Sum the KL terms and divide by the token count to get a per-token mean.
        # ('batchmean' would already divide by the batch size, so dividing that
        # by n_tokens again under-weights the loss whenever batch size > 1; for
        # batch size 1 both forms give the same value.)
        kd_loss = torch.nn.functional.kl_div(
            torch.log_softmax(qmodel_logits / T, dim=-1),
            torch.softmax(fpmodel_logits / T, dim=-1),
            reduction='sum',
        ) * (T ** 2) / n_tokens

        # Only the distillation term is optimized (a mixed objective such as
        # ce_loss + 0.1 * kd_loss is a possible alternative).
        total_loss = kd_loss

        return (total_loss, outputs) if return_outputs else total_loss


# Build the trainer around the student model, the distillation dataset and the
# custom optimizer specification. No evaluation dataset is supplied.
trainer = CustomTrainer(
    model=qmodel,
    args=training_args,
    train_dataset=kd_dataset,
    optimizer_cls_and_kwargs=(optimizer_cls, optimizer_kwargs),
)
trainer.can_return_loss = True

# Run the whole training loop under CUDA float16 autocast.
fp16_autocast = torch.amp.autocast('cuda', dtype=torch.float16)
with fp16_autocast:
    trainer.train()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainalble params: 2764800 Fraction of fp model: 0.04103041660765378%
ce_loss: tensor(1.4146, device='cuda:0', grad_fn=<NllLossBackward0>)


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


ce_loss: tensor(3.8428, device='cuda:0', grad_fn=<NllLossBackward0>)


In [7]:
# Name the checkpoint after the number of optimizer steps actually completed
# (tracked by the Trainer) instead of a hand-edited constant, so the suffix is
# correct whether training ran for all n_steps or was stopped early.
completed_steps = trainer.state.global_step
qmodel.save_pretrained(
    '/home/msst/repo/Quantization/logs/checkpoints_Llama2-7b-hf/trellis/T256_L16_V2_K2_lbits10_LowBitSym_qtip'
    + f'_QAT_KDLoss_{completed_steps}steps'
)