In [None]:
# 当前基础环境：py310-torch231-cu121

# 在基础环境之上安装依赖
! pip install swanlab modelscope transformers datasets peft accelerate

In [2]:
from modelscope import MsDataset

# 假设你已经正确加载了 MsDataset 数据集
dataset = MsDataset.load('swift/zh_cls_fudan-news', split='train')
test_dataset = MsDataset.load('swift/zh_cls_fudan-news', subset_name='test', split='test')

print(dataset)
print(test_dataset)

Dataset({
    features: ['text', 'category', 'output'],
    num_rows: 4000
})
Dataset({
    features: ['text', 'category', 'output'],
    num_rows: 959
})


In [2]:
import json

# 转换数据
def construct_instruction_dataset(dataset, output_file):
    messages = []
    
    for data in dataset:  # 遍历加载的数据集
        context = data["text"]      # 获取文本内容
        category = data["category"]  # 获取分类选项
        output = data["output"]         # 获取标签（输出）
        
        # 构造指令格式的数据
        message = {
            "instruction": "作为文本分类专家，请根据给定的文本和分类选项，确定正确的文本类型。",
            "input": f"文本: {context}\n分类选项: {category}",
            "output": output
        }
        messages.append(message)

    # 写入jsonl文件
    with open(output_file, 'w', encoding='utf-8') as file:
        for message in messages:
            file.write(json.dumps(message, ensure_ascii=False) + '\n')

# 执行转换
construct_instruction_dataset(dataset, 'instruction_dataset.jsonl')
construct_instruction_dataset(test_dataset, 'instruction_dataset_test.jsonl')
print("转换完成")

转换完成


In [3]:
# 加载数据集
from datasets import Dataset
import pandas as pd

df = pd.read_json('./instruction_dataset.jsonl', lines=True)
ds = Dataset.from_pandas(df)


In [5]:
# 下载模型
import torch
from modelscope import snapshot_download

model_dir = snapshot_download('qwen/Qwen1.5-7B-Chat', cache_dir='D:/.cache/modelscope/', revision='master')


Downloading [config.json]: 100%|██████████| 663/663 [00:00<00:00, 1.82kB/s]
Downloading [configuration.json]: 100%|██████████| 51.0/51.0 [00:00<00:00, 150B/s]
Downloading [generation_config.json]: 100%|██████████| 243/243 [00:00<00:00, 713B/s]
Downloading [LICENSE]: 100%|██████████| 6.73k/6.73k [00:00<00:00, 19.4kB/s]
Downloading [merges.txt]: 100%|██████████| 1.59M/1.59M [00:05<00:00, 328kB/s]
Downloading [model-00001-of-00004.safetensors]: 3.74GB [05:22, 12.5MB/s]
Downloading [model-00002-of-00004.safetensors]: 100%|██████████| 3.69G/3.69G [01:54<00:00, 34.5MB/s]
Downloading [model-00003-of-00004.safetensors]: 100%|██████████| 3.69G/3.69G [02:09<00:00, 30.6MB/s]
Downloading [model-00004-of-00004.safetensors]: 100%|██████████| 3.30G/3.30G [01:25<00:00, 41.3MB/s]
Downloading [model.safetensors.index.json]: 100%|██████████| 31.0k/31.0k [00:00<00:00, 67.0kB/s]
Downloading [README.md]: 100%|██████████| 4.15k/4.15k [00:00<00:00, 9.56kB/s]
Downloading [tokenizer.json]: 100%|██████████| 6.70

In [8]:
# 加载模型
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('D:/.cache/modelscope/qwen/Qwen1___5-7B-Chat/', use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('D:/.cache/modelscope/qwen/Qwen1___5-7B-Chat/', device_map="auto", torch_dtype=torch.bfloat16)
model.enable_input_require_grads()


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "d:\miniconda3\envs\test_lora\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "d:\miniconda3\envs\test_lora\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "d:\miniconda3\envs\test_lora\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "d:\miniconda3\envs\test_lora\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
    app.start()
  File "d:\miniconda3\envs\te

In [9]:
# 数据预处理
def preprocess_function(example):
    MAX_LENGTH = 384
    prompt = f"<|im_start|>system\n{example['instruction']}<|im_end|>\n<|im_start|>user\n{example['input']}<|im_end|>\n<|im_start|>assistant\n"
    
    inputs = tokenizer(prompt, add_special_tokens=False)
    outputs = tokenizer(f"{example['output']}", add_special_tokens=False)
    
    input_ids = inputs["input_ids"] + outputs["input_ids"] + [tokenizer.eos_token_id]
    attention_mask = inputs["attention_mask"] + outputs["attention_mask"] + [1]
    labels = [-100] * len(inputs["input_ids"]) + outputs["input_ids"] + [tokenizer.eos_token_id]
    
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

tokenized_dataset = ds.map(preprocess_function, remove_columns=ds.column_names)

Map: 100%|██████████| 4000/4000 [15:07<00:00,  4.41 examples/s]


In [10]:
# 设置lora参数
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1
)

In [11]:
# 应用lora配置
from peft import get_peft_model

model = get_peft_model(model, lora_config)

In [12]:
# 设置训练参数
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output/Qwen1.5",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=3,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True
)

In [13]:
# 使用swanlab监控
from swanlab.integration.huggingface import SwanLabCallback

swanlab_callback = SwanLabCallback(project="Qwen1.5-Finetune")

In [14]:
# 开始训练
from transformers import Trainer, DataCollatorForSeq2Seq

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[swanlab_callback],
)

trainer.train()

[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.3.19                                  
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1md:\Space\PRO\test\test_lora\swanlog\run-20240910_045709-a3b1799d[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39msamge[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mexp_Sep10_04-57-09[0m to the cloud
[1m[34mswanlab[0m[0m: 🌟 Run `[1mswanlab watch d:\Space\PRO\test\test_lora\swanlog[0m` to view SwanLab Experiment Dashboard locally
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@samge/Qwen1.5-Finetune[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@samge/Qwen1.5-Finetune/runs/pe3331lvt83drqrsx40n1[0m[0m


  0%|          | 0/750 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
  attn_output = torch.nn.functional.scaled_dot_product_attention(
  1%|▏         | 10/750 [01:11<1:23:43,  6.79s/it]

{'loss': 15.2816, 'grad_norm': 0.3484857380390167, 'learning_rate': 9.866666666666668e-05, 'epoch': 0.04}


  3%|▎         | 20/750 [02:19<1:22:34,  6.79s/it]

{'loss': 0.0086, 'grad_norm': 0.011960361152887344, 'learning_rate': 9.733333333333335e-05, 'epoch': 0.08}


  4%|▍         | 30/750 [03:26<1:21:15,  6.77s/it]

{'loss': 17.1864, 'grad_norm': 12.014134407043457, 'learning_rate': 9.6e-05, 'epoch': 0.12}


  5%|▌         | 40/750 [04:34<1:20:26,  6.80s/it]

{'loss': 0.2126, 'grad_norm': 5.828611850738525, 'learning_rate': 9.466666666666667e-05, 'epoch': 0.16}


  7%|▋         | 50/750 [05:42<1:19:37,  6.83s/it]

{'loss': 0.1479, 'grad_norm': 8.789040565490723, 'learning_rate': 9.333333333333334e-05, 'epoch': 0.2}


  8%|▊         | 60/750 [06:50<1:17:55,  6.78s/it]

{'loss': 16.6769, 'grad_norm': 0.18294575810432434, 'learning_rate': 9.200000000000001e-05, 'epoch': 0.24}


  9%|▉         | 70/750 [07:58<1:17:13,  6.81s/it]

{'loss': 1.2039, 'grad_norm': 0.013400505296885967, 'learning_rate': 9.066666666666667e-05, 'epoch': 0.28}


 11%|█         | 80/750 [09:07<1:16:02,  6.81s/it]

{'loss': 2.0064, 'grad_norm': 0.01734144426882267, 'learning_rate': 8.933333333333334e-05, 'epoch': 0.32}


 12%|█▏        | 90/750 [10:15<1:14:46,  6.80s/it]

{'loss': 0.3253, 'grad_norm': 0.0040741898119449615, 'learning_rate': 8.800000000000001e-05, 'epoch': 0.36}


 13%|█▎        | 100/750 [11:22<1:13:15,  6.76s/it]

{'loss': 2.3989, 'grad_norm': 0.0, 'learning_rate': 8.666666666666667e-05, 'epoch': 0.4}


 15%|█▍        | 110/750 [12:31<1:12:59,  6.84s/it]

{'loss': 5.967, 'grad_norm': 1.4720704555511475, 'learning_rate': 8.533333333333334e-05, 'epoch': 0.44}


 16%|█▌        | 120/750 [13:40<1:11:27,  6.81s/it]

{'loss': 0.5107, 'grad_norm': 0.0007368105580098927, 'learning_rate': 8.4e-05, 'epoch': 0.48}


 17%|█▋        | 130/750 [14:47<1:10:08,  6.79s/it]

{'loss': 0.2246, 'grad_norm': 0.004771663341671228, 'learning_rate': 8.266666666666667e-05, 'epoch': 0.52}


 19%|█▊        | 140/750 [15:55<1:09:03,  6.79s/it]

{'loss': 0.0029, 'grad_norm': 0.00966606941074133, 'learning_rate': 8.133333333333334e-05, 'epoch': 0.56}


 20%|██        | 150/750 [17:03<1:07:51,  6.79s/it]

{'loss': 1.0327, 'grad_norm': 11.649827003479004, 'learning_rate': 8e-05, 'epoch': 0.6}


 21%|██▏       | 160/750 [18:11<1:07:04,  6.82s/it]

{'loss': 0.0536, 'grad_norm': 0.007207191549241543, 'learning_rate': 7.866666666666666e-05, 'epoch': 0.64}


 23%|██▎       | 170/750 [19:20<1:05:41,  6.80s/it]

{'loss': 0.0806, 'grad_norm': 0.028032319620251656, 'learning_rate': 7.733333333333333e-05, 'epoch': 0.68}


 24%|██▍       | 180/750 [20:28<1:04:38,  6.80s/it]

{'loss': 1.1424, 'grad_norm': 0.00332365813665092, 'learning_rate': 7.6e-05, 'epoch': 0.72}


 25%|██▌       | 190/750 [21:35<1:03:17,  6.78s/it]

{'loss': 0.0021, 'grad_norm': 0.009555951692163944, 'learning_rate': 7.466666666666667e-05, 'epoch': 0.76}


 27%|██▋       | 200/750 [22:43<1:02:26,  6.81s/it]

{'loss': 6.4357, 'grad_norm': 25.89899253845215, 'learning_rate': 7.333333333333333e-05, 'epoch': 0.8}


 28%|██▊       | 210/750 [23:53<1:01:40,  6.85s/it]

{'loss': 1.2124, 'grad_norm': 0.2879907488822937, 'learning_rate': 7.2e-05, 'epoch': 0.84}


 29%|██▉       | 220/750 [25:00<1:00:01,  6.80s/it]

{'loss': 0.0071, 'grad_norm': 0.010140646249055862, 'learning_rate': 7.066666666666667e-05, 'epoch': 0.88}


 31%|███       | 230/750 [26:08<58:51,  6.79s/it]  

{'loss': 0.0057, 'grad_norm': 0.010391815565526485, 'learning_rate': 6.933333333333334e-05, 'epoch': 0.92}


 32%|███▏      | 240/750 [27:17<57:58,  6.82s/it]

{'loss': 3.6255, 'grad_norm': 0.016633223742246628, 'learning_rate': 6.800000000000001e-05, 'epoch': 0.96}


 33%|███▎      | 250/750 [28:25<56:48,  6.82s/it]

{'loss': 1.7426, 'grad_norm': 0.023199576884508133, 'learning_rate': 6.666666666666667e-05, 'epoch': 1.0}


 35%|███▍      | 260/750 [29:33<55:22,  6.78s/it]

{'loss': 0.4168, 'grad_norm': 0.0, 'learning_rate': 6.533333333333334e-05, 'epoch': 1.04}


 36%|███▌      | 270/750 [30:41<54:29,  6.81s/it]

{'loss': 0.2702, 'grad_norm': 0.18347449600696564, 'learning_rate': 6.400000000000001e-05, 'epoch': 1.08}


 37%|███▋      | 280/750 [31:49<53:08,  6.78s/it]

{'loss': 0.3036, 'grad_norm': 0.0, 'learning_rate': 6.266666666666667e-05, 'epoch': 1.12}


 39%|███▊      | 290/750 [32:57<52:04,  6.79s/it]

{'loss': 1.332, 'grad_norm': 1.1467218399047852, 'learning_rate': 6.133333333333334e-05, 'epoch': 1.16}


 40%|████      | 300/750 [34:05<51:10,  6.82s/it]

{'loss': 0.005, 'grad_norm': 0.4076959788799286, 'learning_rate': 6e-05, 'epoch': 1.2}


 41%|████▏     | 310/750 [35:14<50:11,  6.84s/it]

{'loss': 0.0302, 'grad_norm': 0.0033045571763068438, 'learning_rate': 5.866666666666667e-05, 'epoch': 1.24}


 43%|████▎     | 320/750 [36:22<48:48,  6.81s/it]

{'loss': 0.0532, 'grad_norm': 8.075716972351074, 'learning_rate': 5.7333333333333336e-05, 'epoch': 1.28}


 44%|████▍     | 330/750 [37:30<47:33,  6.79s/it]

{'loss': 3.7878, 'grad_norm': 0.003436617087572813, 'learning_rate': 5.6000000000000006e-05, 'epoch': 1.32}


 45%|████▌     | 340/750 [38:38<46:30,  6.81s/it]

{'loss': 0.0009, 'grad_norm': 0.008700637146830559, 'learning_rate': 5.466666666666666e-05, 'epoch': 1.36}


 47%|████▋     | 350/750 [39:46<45:15,  6.79s/it]

{'loss': 0.009, 'grad_norm': 0.0005118648405186832, 'learning_rate': 5.333333333333333e-05, 'epoch': 1.4}


 48%|████▊     | 360/750 [40:54<44:17,  6.81s/it]

{'loss': 0.019, 'grad_norm': 0.008285238407552242, 'learning_rate': 5.2000000000000004e-05, 'epoch': 1.44}


 49%|████▉     | 370/750 [42:02<43:07,  6.81s/it]

{'loss': 0.0053, 'grad_norm': 0.0015612270217388868, 'learning_rate': 5.0666666666666674e-05, 'epoch': 1.48}


 51%|█████     | 380/750 [43:10<41:55,  6.80s/it]

{'loss': 0.0012, 'grad_norm': 0.008719929493963718, 'learning_rate': 4.933333333333334e-05, 'epoch': 1.52}


 52%|█████▏    | 390/750 [44:18<40:41,  6.78s/it]

{'loss': 0.7008, 'grad_norm': 0.001449528499506414, 'learning_rate': 4.8e-05, 'epoch': 1.56}


 53%|█████▎    | 400/750 [45:26<39:39,  6.80s/it]

{'loss': 1.2397, 'grad_norm': 0.006278482265770435, 'learning_rate': 4.666666666666667e-05, 'epoch': 1.6}


 55%|█████▍    | 410/750 [46:35<38:40,  6.83s/it]

{'loss': 0.0294, 'grad_norm': 9.826807022094727, 'learning_rate': 4.5333333333333335e-05, 'epoch': 1.64}


 56%|█████▌    | 420/750 [47:43<37:30,  6.82s/it]

{'loss': 0.1536, 'grad_norm': 0.0018989763921126723, 'learning_rate': 4.4000000000000006e-05, 'epoch': 1.68}


 57%|█████▋    | 430/750 [48:51<36:22,  6.82s/it]

{'loss': 1.2261, 'grad_norm': 0.00921633094549179, 'learning_rate': 4.266666666666667e-05, 'epoch': 1.72}


 59%|█████▊    | 440/750 [49:59<35:08,  6.80s/it]

{'loss': 0.1244, 'grad_norm': 0.002830774523317814, 'learning_rate': 4.133333333333333e-05, 'epoch': 1.76}


 60%|██████    | 450/750 [51:07<33:56,  6.79s/it]

{'loss': 0.0258, 'grad_norm': 0.0014379501808434725, 'learning_rate': 4e-05, 'epoch': 1.8}


 61%|██████▏   | 460/750 [52:15<32:51,  6.80s/it]

{'loss': 0.0473, 'grad_norm': 0.011072887107729912, 'learning_rate': 3.866666666666667e-05, 'epoch': 1.84}


 63%|██████▎   | 470/750 [53:23<31:36,  6.77s/it]

{'loss': 0.0106, 'grad_norm': 0.0, 'learning_rate': 3.733333333333334e-05, 'epoch': 1.88}


 64%|██████▍   | 480/750 [54:31<30:41,  6.82s/it]

{'loss': 0.0195, 'grad_norm': 0.008419172838330269, 'learning_rate': 3.6e-05, 'epoch': 1.92}


 65%|██████▌   | 490/750 [55:39<29:26,  6.80s/it]

{'loss': 0.3243, 'grad_norm': 0.0, 'learning_rate': 3.466666666666667e-05, 'epoch': 1.96}


 67%|██████▋   | 500/750 [56:47<28:28,  6.83s/it]

{'loss': 0.0234, 'grad_norm': 0.013071675784885883, 'learning_rate': 3.3333333333333335e-05, 'epoch': 2.0}


 68%|██████▊   | 510/750 [57:56<27:09,  6.79s/it]

{'loss': 0.0446, 'grad_norm': 0.0010200308170169592, 'learning_rate': 3.2000000000000005e-05, 'epoch': 2.04}


 69%|██████▉   | 520/750 [59:04<26:04,  6.80s/it]

{'loss': 0.881, 'grad_norm': 0.004574589431285858, 'learning_rate': 3.066666666666667e-05, 'epoch': 2.08}


 71%|███████   | 530/750 [1:00:12<24:53,  6.79s/it]

{'loss': 0.0013, 'grad_norm': 0.0012571322731673717, 'learning_rate': 2.9333333333333336e-05, 'epoch': 2.12}


 72%|███████▏  | 540/750 [1:01:20<23:44,  6.78s/it]

{'loss': 0.0005, 'grad_norm': 0.0013010511174798012, 'learning_rate': 2.8000000000000003e-05, 'epoch': 2.16}


 73%|███████▎  | 550/750 [1:02:28<22:36,  6.78s/it]

{'loss': 0.002, 'grad_norm': 0.0020272042602300644, 'learning_rate': 2.6666666666666667e-05, 'epoch': 2.2}


 75%|███████▍  | 560/750 [1:03:36<21:34,  6.81s/it]

{'loss': 0.003, 'grad_norm': 0.0024826866574585438, 'learning_rate': 2.5333333333333337e-05, 'epoch': 2.24}


 76%|███████▌  | 570/750 [1:04:44<20:21,  6.79s/it]

{'loss': 0.0025, 'grad_norm': 0.058298878371715546, 'learning_rate': 2.4e-05, 'epoch': 2.28}


 77%|███████▋  | 580/750 [1:05:52<19:17,  6.81s/it]

{'loss': 0.0004, 'grad_norm': 0.002474732929840684, 'learning_rate': 2.2666666666666668e-05, 'epoch': 2.32}


 79%|███████▊  | 590/750 [1:07:00<18:12,  6.83s/it]

{'loss': 0.0246, 'grad_norm': 0.014275957830250263, 'learning_rate': 2.1333333333333335e-05, 'epoch': 2.36}


 80%|████████  | 600/750 [1:08:08<16:58,  6.79s/it]

{'loss': 0.0025, 'grad_norm': 0.0018427488394081593, 'learning_rate': 2e-05, 'epoch': 2.4}


 81%|████████▏ | 610/750 [1:09:17<15:47,  6.77s/it]

{'loss': 0.021, 'grad_norm': 0.0, 'learning_rate': 1.866666666666667e-05, 'epoch': 2.44}


 83%|████████▎ | 620/750 [1:10:25<14:46,  6.82s/it]

{'loss': 0.0197, 'grad_norm': 0.0007808062946423888, 'learning_rate': 1.7333333333333336e-05, 'epoch': 2.48}


 84%|████████▍ | 630/750 [1:11:33<13:35,  6.80s/it]

{'loss': 0.0041, 'grad_norm': 0.0011206006165593863, 'learning_rate': 1.6000000000000003e-05, 'epoch': 2.52}


 85%|████████▌ | 640/750 [1:12:41<12:29,  6.82s/it]

{'loss': 0.0222, 'grad_norm': 0.001649277750402689, 'learning_rate': 1.4666666666666668e-05, 'epoch': 2.56}


 87%|████████▋ | 650/750 [1:13:49<11:18,  6.78s/it]

{'loss': 32.9749, 'grad_norm': 0.0, 'learning_rate': 1.3333333333333333e-05, 'epoch': 2.6}


 88%|████████▊ | 660/750 [1:14:57<10:08,  6.76s/it]

{'loss': 0.0019, 'grad_norm': 0.0, 'learning_rate': 1.2e-05, 'epoch': 2.64}


 89%|████████▉ | 670/750 [1:16:05<09:02,  6.79s/it]

{'loss': 0.002, 'grad_norm': 0.0017756301676854491, 'learning_rate': 1.0666666666666667e-05, 'epoch': 2.68}


 91%|█████████ | 680/750 [1:17:13<07:55,  6.79s/it]

{'loss': 0.0139, 'grad_norm': 0.0, 'learning_rate': 9.333333333333334e-06, 'epoch': 2.72}


 92%|█████████▏| 690/750 [1:18:21<06:48,  6.81s/it]

{'loss': 0.0026, 'grad_norm': 0.0014191354857757688, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.76}


 93%|█████████▎| 700/750 [1:19:29<05:39,  6.78s/it]

{'loss': 0.0072, 'grad_norm': 0.0031364934984594584, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.8}


 95%|█████████▍| 710/750 [1:20:38<04:33,  6.84s/it]

{'loss': 0.1089, 'grad_norm': 0.0011811187723651528, 'learning_rate': 5.333333333333334e-06, 'epoch': 2.84}


 96%|█████████▌| 720/750 [1:21:45<03:22,  6.76s/it]

{'loss': 0.0803, 'grad_norm': 0.0, 'learning_rate': 4.000000000000001e-06, 'epoch': 2.88}


 97%|█████████▋| 730/750 [1:22:53<02:15,  6.77s/it]

{'loss': 0.0022, 'grad_norm': 0.0, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.92}


 99%|█████████▊| 740/750 [1:24:02<01:08,  6.83s/it]

{'loss': 0.004, 'grad_norm': 0.00223860633559525, 'learning_rate': 1.3333333333333334e-06, 'epoch': 2.96}


100%|██████████| 750/750 [1:25:10<00:00,  6.83s/it]

{'loss': 0.045, 'grad_norm': 0.0019128503045067191, 'learning_rate': 0.0, 'epoch': 3.0}


100%|██████████| 750/750 [1:25:11<00:00,  6.82s/it]

[1m[33mswanlab[0m[0m: Step 750 on key train/epoch already exists, ignored.
{'train_runtime': 5121.1254, 'train_samples_per_second': 2.343, 'train_steps_per_second': 0.146, 'train_loss': 1.625669944545254, 'epoch': 3.0}





TrainOutput(global_step=750, training_loss=1.625669944545254, metrics={'train_runtime': 5121.1254, 'train_samples_per_second': 2.343, 'train_steps_per_second': 0.146, 'total_flos': 1.96825646628864e+17, 'train_loss': 1.625669944545254, 'epoch': 3.0})

In [25]:
# 模型测试
from peft import PeftModel

lora_path = 'output/Qwen1.5/checkpoint-700'
model = PeftModel.from_pretrained(model, model_id=lora_path)

def predict(text, category_options):
    prompt = f"文本:{text},类型选项:{category_options}"
    messages = [
        {"role": "system", "content": "你是一个文本分类领域的专家，请根据给定的文本和分类选项，确定正确的文本类型。不需要解释，直接输出预测类型，如果没有命中的类型，直接返回：其他。"},
        {"role": "user", "content": prompt}
    ]
    
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([input_text], return_tensors="pt").to('cuda')
    
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_new_tokens=512
        )
    
    response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
    return response.strip()

# 使用示例
test_text = "马斯克：20年内容建立火星基地。"
test_categories = "体育,科技,娱乐,政治"
result = predict(test_text, test_categories)
print(f"预测结果: {result}")

预测结果: 科技
