In [1]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from loguru import logger
from modeling_classifier import BaichuanForSequenceClassification
import bitsandbytes as bnb
from transformers import BitsAndBytesConfig
from trl import DPOTrainer, get_kbit_device_map
import torch 

from transformers import AutoConfig, AutoTokenizer

def load_tokenizer(args):
    """
    Load the tokenizer for ``args.model_name_or_path`` and make sure it can pad.

    Args:
        args: object exposing ``model_name_or_path`` (the checkpoint location
            handed to ``AutoTokenizer.from_pretrained``).

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``. If it had
        no pad token, its EOS token is reused as the pad token.

    Raises:
        AssertionError: if the pad or EOS token id is still ``None`` afterwards.
    """
    # Only the tokenizer is needed here; the model config is not loaded.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        # NOTE(review): an older comment on this argument said LLaMA does not
        # support fast tokenizers, yet use_fast=True is what is requested —
        # confirm which one is intended.
        use_fast=True,
    )
    # Padding needs a pad token; fall back to EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    assert tokenizer.pad_token_id is not None, "pad_token_id should not be None"
    assert tokenizer.eos_token_id is not None, "eos_token_id should not be None"
    logger.info(f'vocab_size of tokenizer: {tokenizer.vocab_size}')
    return tokenizer

def load_model(args):
    """
    Load ``BaichuanForSequenceClassification`` and, unless fully fine-tuning,
    wrap it with LoRA adapters.

    Args:
        args: configuration object. Attributes read:
            ``model_name_or_path``, ``train_mode`` (the code branches on
            ``'full'``, ``'lora'`` and ``'qlora'``), ``fp16`` and
            ``gradient_checkpointing``; when ``train_mode != 'full'`` also
            ``lora_rank``, ``lora_alpha``, ``lora_dropout``, ``task_type`` and
            the optional ``target_modules`` (defaults to ``['W_pack']``).

    Returns:
        dict: ``{'model': model, 'peft_config': peft_config}`` where
        ``peft_config`` is ``None`` only when ``train_mode == 'full'``.
    """
    logger.info(f'Loading model from base model: {args.model_name_or_path}')
    logger.info(f'Train model with {args.train_mode}')

    # init model kwargs
    # TODO: add flash attention (attn_implementation) support.
    torch_dtype = torch.float16 if args.fp16 else torch.float32
    if args.train_mode == 'qlora':
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            # NOTE(review): the compute dtype is fixed to fp16 regardless of
            # args.fp16 — confirm this is intended.
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
        )
    else:
        quantization_config = None
    logger.info(quantization_config)
    model_kwargs = dict(
        trust_remote_code=True,
        torch_dtype=torch_dtype,
        # use_cache is switched off whenever gradient checkpointing is requested.
        use_cache=not args.gradient_checkpointing,
        # A k-bit device map is only requested for the quantized load.
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = BaichuanForSequenceClassification.from_pretrained(args.model_name_or_path, **model_kwargs)

    # MoE models: the load-balancing loss has to be taken into account.
    if 'output_router_logits' in model.config.to_dict():
        logger.info('set output_router_logits as True')
        model.config.output_router_logits = True
    # QLoRA: casts all the non int8 modules to full precision (fp32) for stability
    if args.train_mode == 'qlora':
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
    # LoRA: Enables the gradients for the input embeddings
    if args.train_mode == 'lora':
        # For backward compatibility
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    # init peft_config
    if args.train_mode == 'full':
        peft_config = None
    else:
        # Modules that receive an adapter. 'W_pack' stays the default; callers
        # may override it by setting ``args.target_modules``.
        target_modules = getattr(args, 'target_modules', None) or ['W_pack']
        peft_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            target_modules=target_modules,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type=args.task_type,
        )

    # init peft model
    if args.train_mode in ['lora', 'qlora']:
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')

    # Count the model parameters.
    total = sum(p.numel() for p in model.parameters())
    logger.info("Total model params: %.2fM" % (total / 1e6))

    return {
        'model': model,
        'peft_config': peft_config
    }

def memory_stats():
    """
    Print PyTorch's CUDA memory usage in MiB.

    Two lines are printed: the allocated amount first, then the reserved
    amount. Nothing is returned.
    """
    mib = 1024 ** 2
    print(torch.cuda.memory_allocated() / mib)
    # memory_cached() is deprecated; memory_reserved() is its replacement.
    print(torch.cuda.memory_reserved() / mib)

    PyTorch 2.1.0a0+4136153 with CUDA 1201 (you have 2.0.0+cu117)
    Python  3.10.12 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


[2024-03-01 17:53:12,096] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
class Arguments:
    """
    Plain settings container consumed by ``load_tokenizer`` and ``load_model``.

    Every setting is a keyword parameter whose default is the value this
    notebook was run with, so ``Arguments()`` behaves exactly as before while
    ``Arguments(train_mode='lora')`` overrides a single field.
    """

    def __init__(
        self,
        model_name_or_path='/DATA/jupyter/share/LLM_NBS/Baichuan2-7B-Chat/',
        fp16=True,
        train_mode='qlora',
        lora_rank=64,
        lora_dropout=0.05,
        lora_alpha=16,
        gradient_checkpointing=False,
        task_type='SEQ_CLS',
    ):
        # Checkpoint location handed to the from_pretrained calls.
        self.model_name_or_path = model_name_or_path
        # True -> load weights as float16, otherwise float32 (see load_model).
        self.fp16 = fp16
        # load_model branches on 'full', 'lora' and 'qlora'.
        self.train_mode = train_mode
        # LoRA hyper-parameters forwarded to LoraConfig.
        self.lora_rank = lora_rank
        self.lora_dropout = lora_dropout
        self.lora_alpha = lora_alpha
        # When truthy, load_model loads the model with use_cache disabled.
        self.gradient_checkpointing = gradient_checkpointing
        # PEFT task type forwarded to LoraConfig.
        self.task_type = task_type
# Prompt-format description for the 'baichuan2' template.
template = {
    'template_name': 'baichuan2',
    'system_format': None,
    'user_format': '<reserved_106>{content}<reserved_107>',
    'assistant_format': '{content}</s>',
    'system': None,
    'stop_word': '</s>',
}

# Clear PyTorch's CUDA cache before loading the model.
torch.cuda.empty_cache()

# Build the settings, then load the tokenizer and the model.
args = Arguments()
tokenizer = load_tokenizer(args)
d = load_model(args)  # dict with 'model' and 'peft_config'
model = d['model']

[32m2024-03-01 17:53:15.952[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_tokenizer[0m:[36m24[0m - [1mvocab_size of tokenizer: 125696[0m
[32m2024-03-01 17:53:15.953[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_model[0m:[36m31[0m - [1mLoading model from base model: /DATA/jupyter/share/LLM_NBS/Baichuan2-7B-Chat/[0m
[32m2024-03-01 17:53:15.954[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_model[0m:[36m32[0m - [1mTrain model with qlora[0m
[32m2024-03-01 17:53:15.956[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_model[0m:[36m49[0m - [1mBitsAndBytesConfig {
  "bnb_4bit_compute_dtype": "float16",
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_use_double_quant": true,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": true,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}
[0m
The argument `trust_remote_code` is to be used wi

trainable params: 33,562,626 || all params: 7,024,693,252 || trainable%: 0.4777806630979137


In [3]:
# Toy batch used to smoke-test the classifier's forward pass.
text = ['你是谁', '这个多少钱', 'what is wrong with u']
labels = torch.tensor([0, 1, 0]).long()

# Wrap every text in the user template individually (formatting the whole
# list at once would embed the list's repr into a single prompt).
message = [template['user_format'].format(content=t) for t in text]
# NOTE(review): the raw `text`, not `message`, is what gets tokenized below —
# confirm whether the templated prompts were meant to be encoded instead.
encode = tokenizer.batch_encode_plus(
    text,
    return_tensors='pt',
    padding='longest',
    truncation=True,  # required for max_length to actually be enforced
    max_length=1024,
)
# Per-row count of tokens that differ from the pad token id.
# NOTE(review): this undercounts if the pad token id can also occur inside a
# text (e.g. when pad falls back to EOS) — confirm against the tokenizer.
encode['prompt_lengths'] = (encode.input_ids != tokenizer.pad_token_id).sum(axis=1)
encode['labels'] = labels
# Move every tensor onto the model's device.
for k, v in encode.items():
    encode[k] = v.to(model.device)
# Same batch without the extra 'prompt_lengths' entry.
encode2 = {k: v for k, v in encode.items() if k != 'prompt_lengths'}

In [4]:
# Run one forward pass on the toy batch; `encode` also carries `labels` and
# `prompt_lengths`, which are forwarded to the model as keyword arguments.
o = model(**encode, return_dict=True)
# Display the returned output object.
o

1.8710370063781738, True
SequenceClassifierOutput(loss=tensor(1.8710, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[-0.2901,  0.7803],
        [ 1.7586, -2.3720],
        [-0.5060, -2.7436]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


SequenceClassifierOutput(loss=tensor(1.8710, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[-0.2901,  0.7803],
        [ 1.7586, -2.3720],
        [-0.5060, -2.7436]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)