In [1]:
# Hugging Face cache locations. Set in the first cell so they are already in place
# when transformers / datasets are imported further down.
# Uncomment HF_ENDPOINT to download through the hf-mirror.com mirror instead of huggingface.co.
# %env HF_ENDPOINT=https://hf-mirror.com
%env HF_HOME=/root/autodl-tmp/hf
# NOTE(review): HF_HUB_CACHE normally defaults to $HF_HOME/hub; pointing it at HF_HOME
# itself puts model snapshots next to other HF files. Left as-is so the existing
# downloaded checkpoint in this directory keeps being reused.
%env HF_HUB_CACHE=/root/autodl-tmp/hf

env: HF_HOME=/root/autodl-tmp/hf
env: HF_HUB_CACHE=/root/autodl-tmp/hf


In [2]:
import os
import subprocess

# Shell script that exports proxy settings (presumably AutoDL's "network turbo"
# accelerator). `source` only affects the shell that runs it, so it is sourced in a
# child bash whose resulting environment is copied back into this kernel.
NETWORK_TURBO_SCRIPT = "/etc/network_turbo"


def load_proxy_env(script_path=NETWORK_TURBO_SCRIPT):
    """Source ``script_path`` in a child bash and copy its proxy variables into os.environ.

    Only variables whose *name* contains "proxy" (case-insensitive) are copied, so
    unrelated variables that merely have "proxy" somewhere in their value are ignored.

    Args:
        script_path: path of the shell script to source.

    Returns:
        dict of the variables copied (name -> value). Empty when the script could
        not be sourced; a warning is printed in that case instead of failing silently.
    """
    # Argument list (no shell=True) avoids a second layer of shell quoting; the path
    # is passed as $1. `env -0` NUL-terminates entries so multi-line values parse safely.
    result = subprocess.run(
        ["bash", "-c", 'source "$1" && env -0', "bash", script_path],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        print(f"warning: could not source {script_path}: {result.stderr.strip()}")
        return {}

    copied = {}
    for entry in result.stdout.split("\0"):
        name, sep, value = entry.partition("=")
        if sep and "proxy" in name.lower():
            copied[name] = value
    os.environ.update(copied)
    return copied


load_proxy_env()

In [3]:
# One-time environment setup; uncomment when running in a fresh environment.
# NOTE(review): prefer `%pip install` (installs into this kernel's environment) and pin
# versions for reproducibility — argument names used below (e.g. SFTConfig's
# max_seq_length / dataset_text_field) may differ between trl releases.
# !pip3 install -q -U bitsandbytes
# !pip3 install -q -U peft
# !pip3 install -q -U trl
# !pip3 install -q -U accelerate
# !pip3 install -q -U datasets
# !pip3 install -q -U transformers

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit quantization settings used when loading the base model:
# NF4 weight format, double quantization enabled, bfloat16 compute dtype.
_bnb_kwargs = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**_bnb_kwargs)


In [5]:
# Interactive Hugging Face login, left disabled.
# NOTE(review): presumably only needed when the google/gemma-2-9b checkpoint is not
# already in the local cache (Gemma weights appear to be gated on the Hub) — confirm;
# exporting HF_TOKEN is a non-interactive alternative.
# from huggingface_hub import notebook_login
# notebook_login()

In [6]:
model_id = "google/gemma-2-9b"

# Load the base checkpoint quantized with bnb_config and map the whole model
# (device_map {"": 0}) onto CUDA device 0.
# NOTE(review): eager attention is presumably chosen because transformers advises
# against sdpa when training Gemma-2 — confirm for the installed version.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    attn_implementation="eager",
)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
# add_eos_token=True asks the tokenizer to append EOS to encoded text, so each
# training sample ends with an explicit end-of-answer marker.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    add_eos_token=True,
)

In [8]:
from datasets import load_dataset

# JSON-lines file of question/answer records, loaded as a single "train" split.
DATA_FILES = "synthetic_data_merge.jsonl"
dataset = load_dataset(
    "json",
    data_files=DATA_FILES,
    split="train",
)

In [9]:
# Show the loaded dataset's columns and row count.
dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 647
})

In [10]:
# Peek at the last three records as a DataFrame.
dataset.to_pandas().tail(3)

Unnamed: 0,question,answer
644,合轴流程的主要目的是什么？,合轴流程主要目的是在单位时间内提高团队输出效率。它通过穿插更多角色的操作来利用长脱手动作导致...
645,在合轴过程中，哪些技能会被用到？,合轴流程中会用到声骸技能（Q）、常态攻击（A）、共鸣技能（E）、重击（Z）、共鸣解放（R）和...
646,给出一个合轴流程的例子？,比如推荐配队中，卡卡罗的技能循环是QEAEAEAZ，维里奈是QTE满后Q，吟霖是EAAAAA...


In [11]:
def generate_prompt(data_point):
    """Render one question/answer record as a training text.

    Args:
        data_point: mapping with "question" and "answer" entries.

    Returns:
        ``"问题: <question>\\n 回答: <answer>"`` — note the space after the newline.
        Training samples use exactly this layout, so inference prompts should match it.
    """
    template = "问题: {}\n 回答: {}"
    return template.format(data_point["question"], data_point["answer"])

In [12]:
# Add the "prompt" column rendered by generate_prompt. `map` is used instead of
# add_column because add_column rejects an existing column name, which made
# re-running this cell fail.
dataset = dataset.map(lambda row: {"prompt": generate_prompt(row)})

dataset = dataset.shuffle(seed=1234)  # Shuffle dataset here
# Truncate to the same 512-token limit configured as max_seq_length for the
# trainer: the data is pre-tokenized here, so that limit may not be re-applied later.
dataset = dataset.map(
    lambda samples: tokenizer(samples["prompt"], truncation=True, max_length=512),
    batched=True,
)

In [13]:
# Hold out 5% of rows for evaluation. train_test_split shuffles again, so it is
# seeded too: otherwise every run (including a resumed one) gets a different split.
split = dataset.train_test_split(test_size=0.05, seed=1234)
train_data = split["train"]
test_data = split["test"]

In [14]:
# Training split: columns and row count.
train_data

Dataset({
    features: ['question', 'answer', 'prompt', 'input_ids', 'attention_mask'],
    num_rows: 614
})

In [15]:
# Held-out split: columns and row count.
test_data

Dataset({
    features: ['question', 'answer', 'prompt', 'input_ids', 'attention_mask'],
    num_rows: 33
})

In [16]:
# First three held-out rows, including the rendered prompt and its token ids.
test_data.to_pandas().head(3)

Unnamed: 0,question,answer,prompt,input_ids,attention_mask
0,“合轴流程合轴”中如何操作配队今汐+吟霖+维里奈？,具体输出手法包括：以配队今汐+吟霖+维里奈为例：吟霖ZRZEZEAAZ（重击后摇靠技能打断，...,问题: “合轴流程合轴”中如何操作配队今汐+吟霖+维里奈？\n 回答: 具体输出手法包括：以...,"[2, 13240, 235292, 1080, 235697, 239302, 54201...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,玩家在游戏中需要与什么生物对抗？,玩家需要与从末日灾难中诞生的怪物“幻象”对抗。,问题: 玩家在游戏中需要与什么生物对抗？\n 回答: 玩家需要与从末日灾难中诞生的怪物“幻象...,"[2, 13240, 235292, 235248, 34673, 235473, 1585...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,在合轴过程中，哪些技能会被用到？,合轴流程中会用到声骸技能（Q）、常态攻击（A）、共鸣技能（E）、重击（Z）、共鸣解放（R）和...,问题: 在合轴过程中，哪些技能会被用到？\n 回答: 合轴流程中会用到声骸技能（Q）、常态攻...,"[2, 13240, 235292, 19183, 235697, 239302, 6393...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [17]:
from peft import (
    LoraConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)

# Apply PEFT's k-bit training preparation to the 4-bit base model before the
# LoRA adapters are attached.
model = prepare_model_for_kbit_training(model)

In [18]:
import bitsandbytes as bnb
def find_all_linear_names(model, cls=None):
    """Return the leaf names of all sub-modules of ``model`` that are ``cls`` instances.

    Intended for LoRA ``target_modules``: a layer named e.g.
    ``model.layers.0.self_attn.q_proj`` contributes ``"q_proj"``. ``lm_head`` is
    always excluded so the output projection is never adapted.

    Args:
        model: any object with ``named_modules()`` (e.g. a ``torch.nn.Module``).
        cls: module class (or tuple of classes) to match. Defaults to
            ``bnb.nn.Linear4bit``, the linear layer type of a 4-bit bitsandbytes model.

    Returns:
        Sorted list of unique leaf names. Sorting makes the result independent of
        set iteration order, which can vary between Python processes.
    """
    if cls is None:
        cls = bnb.nn.Linear4bit
    lora_module_names = {
        name.split('.')[-1]
        for name, module in model.named_modules()
        if isinstance(module, cls)
    }
    # Never wrap the output projection (relevant when cls is a plain 16-bit Linear).
    lora_module_names.discard('lm_head')
    return sorted(lora_module_names)

In [19]:
# Leaf names of every 4-bit linear layer in the model; used as LoRA target_modules below.
modules = find_all_linear_names(model)
print(modules)

['v_proj', 'o_proj', 'gate_proj', 'up_proj', 'q_proj', 'k_proj', 'down_proj']


In [20]:
from peft import LoraConfig, get_peft_model

# LoRA adapters on every linear layer found above. With r=64 and lora_alpha=32
# the adapter update is scaled by alpha/r = 0.5 (PEFT's default, non-rsLoRA scaling).
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=modules,
    r=64,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the prepared base model; from here on `model` is a PeftModel.
model = get_peft_model(model, lora_config)

In [21]:
# Report how many parameters LoRA actually trains relative to the full model.
n_trainable, n_total = model.get_nb_trainable_parameters()
share_pct = n_trainable / n_total * 100
print(f"Trainable: {n_trainable} | total: {n_total} | Percentage: {share_pct:.4f}%")

Trainable: 216072192 | total: 9457778176 | Percentage: 2.2846%


In [22]:
import transformers
import os

from trl import SFTTrainer, SFTConfig


# Only fall back to EOS as the pad token when the tokenizer has none (Gemma's
# tokenizer already defines a dedicated pad token). With mlm=False,
# DataCollatorForLanguageModeling sets the label of every position equal to
# pad_token_id to -100. If pad == eos, the EOS appended by add_eos_token=True
# would therefore be excluded from the loss, and the model would never learn
# where an answer ends.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
torch.cuda.empty_cache()

save_dir = '/root/autodl-tmp/models'
save_name = "google/gemma-2-9b-mingchao-ft"
out_dir = os.path.join(save_dir, save_name)

trainer = SFTTrainer(
    # `model` is already a PeftModel (get_peft_model was applied earlier), so
    # peft_config is deliberately not passed again. Giving SFTTrainer both is
    # redundant, and depending on the TRL version it is either ignored or makes
    # the trainer re-process the adapter.
    model=model,
    train_dataset=train_data,
    eval_dataset=test_data,
    args=SFTConfig(
        overwrite_output_dir=True,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=1,
        warmup_ratio=0.1,
        learning_rate=2e-5,
        logging_strategy="steps",
        logging_steps=0.05,  # values < 1 are interpreted as a fraction of total steps
        output_dir=out_dir,
        optim="paged_adamw_32bit",
        # NOTE(review): with eval_strategy="no", eval_dataset and eval_steps are never
        # used; switch to "steps" to actually evaluate on test_data during training.
        eval_strategy="no",
        eval_steps=0.2,
        save_strategy="steps",
        save_steps=0.1,
        report_to="tensorboard",
        num_train_epochs=3,
        max_seq_length=512,
        dataset_text_field="prompt",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)



In [23]:
import os

from transformers.trainer_utils import get_last_checkpoint

# Resume from the newest checkpoint in the output dir when one exists; otherwise
# start fresh. Passing resume_from_checkpoint=True unconditionally raises ValueError
# when no checkpoint exists yet, so a clean Restart & Run All would fail.
_last_ckpt = (
    get_last_checkpoint(trainer.args.output_dir)
    if os.path.isdir(trainer.args.output_dir)
    else None
)
trainer.train(resume_from_checkpoint=_last_ckpt)

	logging_steps: 0.05 (from args) != 12 (from trainer_state.json)
	eval_steps: 0.2 (from args) != 47 (from trainer_state.json)
	save_steps: 0.1 (from args) != 24 (from trainer_state.json)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
228,1.0178


TrainOutput(global_step=231, training_loss=0.06589622105354871, metrics={'train_runtime': 26.2329, 'train_samples_per_second': 70.217, 'train_steps_per_second': 8.806, 'total_flos': 5525899002839040.0, 'train_loss': 0.06589622105354871, 'epoch': 3.0})

In [24]:
# Write the final model to trainer.args.output_dir. Because `model` is a PeftModel,
# this presumably stores only the LoRA adapter weights — confirm in the output dir.
trainer.save_model()
# Persist trainer_state.json (log history, step counters) next to it.
trainer.save_state()