In [1]:
import os, torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# 1. Expose only physical GPU 4 to this process. This has to run before the
#    first CUDA call in the kernel: once CUDA is initialized, changing the
#    variable no longer affects which devices are visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

model_path = '/data/lxy/Qwen/qwen2.5-7b-base'

# 2. Load the base model in bfloat16 and let accelerate choose the placement
#    (device_map="auto"). `dtype` replaces the `torch_dtype` keyword, which
#    the installed transformers version reports as deprecated.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    dtype=torch.bfloat16,
)

# 3. Wrap the model with LoRA adapters (rank 32, alpha 64) on the modules
#    selected by target_modules='all-linear', then report how many
#    parameters are trainable.
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    task_type=TaskType.CAUSAL_LM,
    target_modules='all-linear',
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.37it/s]


trainable params: 80,740,352 || all params: 7,696,356,864 || trainable%: 1.0491


In [2]:
import os, torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# 1. Expose only physical GPU 5 to this process. The variable is only honoured
#    if it is set before CUDA is initialized; when an earlier cell in the same
#    kernel has already used the GPU, the assignment below is silently
#    ignored and the model lands on the previously visible device.
# NOTE(review): the "offloaded to the cpu" message this cell produced is
# consistent with that situation (sharing a GPU with an already loaded
# model) -- confirm, and restart the kernel before running this cell.
if torch.cuda.is_initialized():
    import warnings

    warnings.warn(
        "CUDA is already initialized in this kernel; changing "
        "CUDA_VISIBLE_DEVICES now has no effect. Restart the kernel first."
    )
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

model_path = '/data/lxy/Qwen/qwen3-8b-base'

# 2. Load the base model in bfloat16 and let accelerate choose the placement
#    (device_map="auto"). `dtype` replaces the deprecated `torch_dtype`
#    keyword.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    dtype=torch.bfloat16,
)

# 3. Wrap the model with LoRA adapters (rank 32, alpha 64) on the modules
#    selected by target_modules='all-linear', then report how many
#    parameters are trainable.
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    task_type=TaskType.CAUSAL_LM,
    target_modules='all-linear',
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s]
Some parameters are on the meta device because they were offloaded to the cpu.


trainable params: 87,293,952 || all params: 8,278,029,312 || trainable%: 1.0545


In [1]:
import os, torch
from transformers import AutoModelForCausalLM,AutoModel,AutoModelForMaskedLM
from peft import LoraConfig, TaskType, get_peft_model

# Expose only physical GPU 4 to this process. This has to run before the
# first CUDA call in the kernel to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

model_path = '/data/lxy/diffusion/llada-8b'

# Load the checkpoint in bfloat16 and let accelerate choose the placement
# (device_map="auto"). `dtype` replaces the `torch_dtype` keyword, which the
# installed transformers version reports as deprecated.
# trust_remote_code=True executes the modelling code shipped with the
# checkpoint, so only use it with checkpoints from a trusted source.
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    dtype=torch.bfloat16,
)

# Wrap the model with LoRA adapters (rank 32, alpha 64) on the modules
# selected by target_modules='all-linear', then report how many parameters
# are trainable.
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    task_type=TaskType.CAUSAL_LM,
    target_modules='all-linear',
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 6/6 [00:01<00:00,  3.69it/s]


trainable params: 83,886,080 || all params: 8,099,467,264 || trainable%: 1.0357


In [9]:
import transformers
from dataclasses import dataclass, field,asdict


@dataclass
class ModelArguments:
    """Model and LoRA settings parsed from the YAML config by HfArgumentParser."""

    # Path (or name) of the base checkpoint to load.
    # NOTE(review): this default points under /data/lxy/diffusion/, while the
    # first cell of this notebook loads qwen2.5-7b-base from /data/lxy/Qwen/ --
    # confirm which location is correct.
    model_name_or_path: str = "/data/lxy/diffusion/qwen2.5-7b-base"
    # Name of the dtype as a string; presumably resolved to a torch dtype by
    # the code that loads the model (not shown here).
    dtype: str = "bfloat16"
    # --- LoRA / PEFT options folded into the model arguments ---
    # `lora` presumably toggles adapter training (verify against the consumer).
    # target_modules, r and lora_alpha carry the same values that are passed
    # to LoraConfig in the model-loading cells of this notebook; lora_dropout
    # and bias are presumably forwarded to LoraConfig as well -- TODO confirm.
    lora: bool = False
    target_modules: str = "all-linear"
    r: int = 32
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    bias: str = "none"


@dataclass
class DataArguments:
    """Dataset settings parsed from the YAML config by HfArgumentParser."""

    # Dataset specification string. The None default is only a placeholder and
    # must be overwritten by the config, e.g.
    # '/data/lxy/diffusion/data/alpaca-zh-gpt[train:2000,test:200]'.
    dataset_args: str = None
    # Presumably the number of worker processes used for dataset preprocessing
    # -- verify against the consumer.
    num_proc: int = 8
    # Presumably turns off the datasets cache when True -- verify against the
    # consumer.
    disable_caching: bool = False
    # Maximum sequence length referred to by the truncation strategy below.
    max_length: int = 1024
    # NOTE(review): the help text says "right" keeps the *rightmost* tokens;
    # right-side truncation conventionally keeps the leftmost ones -- confirm
    # against the implementation that consumes this field.
    truncation: str = field(
        default="right",
        metadata={
            "help": (
                'The truncation strategy to use ("filter" or "right"). '
                '"filter" only keeps sequences that are shorter than max_length; '
                '"right" only keeps the rightmost max_length tokens for each sequence.'
            )
        },
    )
    

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """transformers.TrainingArguments with project-specific default values.

    Only defaults are overridden here; every field can still be set from the
    YAML config parsed by HfArgumentParser.
    """

    # Placeholder default; the config must provide the real output directory.
    output_dir: str = None
    # Experiment tracking backend and the run name reported to it.
    report_to: str = "swanlab"
    run_name: str = "test-1"
    overwrite_output_dir: bool = True
    seed: int = 42
    # Batch sizes are per device; gradients are accumulated over 8 steps.
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 2
    gradient_accumulation_steps: int = 8
    # Optimisation schedule: cosine learning-rate schedule with warmup.
    learning_rate: float = 2e-5
    lr_scheduler_type: str = "cosine"
    warmup_ratio: float = 0.1
    bf16: bool = True
    num_train_epochs: float = 6
    logging_steps: float = 10
    # Evaluation and checkpointing cadence.
    # NOTE(review): 0.25 is presumably interpreted by transformers as a
    # fraction of the total number of training steps rather than a step count
    # -- confirm against the TrainingArguments documentation.
    eval_on_start: bool = False
    eval_strategy: str = "steps"
    eval_steps: float = 0.25
    save_steps: float = 0.25
    save_only_model: bool = True
    save_total_limit: int = 2



In [10]:

# Parse the YAML config into the three argument dataclasses defined above.
config_yaml = './configs/qwen2.5-7b-alpaca.yaml'
argument_classes = (ModelArguments, DataArguments, TrainingArguments)
parser = transformers.HfArgumentParser(argument_classes)
model_args, data_args, training_args = parser.parse_yaml_file(config_yaml)

# Show the parsed dataclass instances.
print("模型参数：", model_args)
print("数据参数：", data_args)
print("训练参数：", training_args)

# Convert each instance to a plain dict and show those as well.
model_d, data_d, training_d = (
    asdict(parsed) for parsed in (model_args, data_args, training_args)
)
print("模型参数字典：", model_d)
print("数据参数字典：", data_d)
print("训练参数字典：", training_d)

模型参数： ModelArguments(model_name_or_path='/data/lxy/diffusion/qwen2.5-7b-base', dtype='bfloat16', lora=True, target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05, bias='none')
数据参数： DataArguments(dataset_args='/data/lxy/diffusion/data/alpaca-zh-gpt[train:2000,test:200]', num_proc=8, disable_caching=False, max_length=1024, truncation='right')
训练参数： TrainingArguments(
_n_gpu=8,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
average_tokens_across_devices=True,
batch_eval_metrics=False,
bf16=True,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcas

In [3]:
import re
def parse_spec(spec: str):
    """
    Parse a 'name[key:value,...]' or 'key=value,...' style specification.

    Accepted forms:
      - Bare name, e.g. "foo/bar".
      - Name followed by a trailing bracket suffix holding comma-separated
        ``key:value`` entries, e.g. "foo/bar[train:2000,test:200]". A value
        made only of digits (optionally grouped with single underscores,
        e.g. "1_000") is converted to ``int``; any other value stays a string.
      - Comma-separated ``key=value`` pairs, e.g. "a=1,b=2" (values stay
        strings). When the text outside the bracket contains "=", all of it
        is treated as key=value pairs: no name is returned and comma-separated
        parts without "=" are ignored.

    Whitespace around keys and values is stripped in both forms.

    Args:
      spec: The specification string.

    Returns:
      A ``(name, kv_dict)`` tuple. ``name`` is the text before the bracket
      suffix, or None when it is empty or consists of key=value pairs.
      ``kv_dict`` combines all pairs; on duplicate keys bracket entries
      override outer key=value pairs, and integer-valued bracket entries
      override string-valued ones.

    Raises:
      ValueError: If a non-empty bracket entry has no ":" separator.
    """

    def _parse_kv_string(s: str) -> dict:
        """Parse comma-separated key=value pairs, e.g. 'a=1,b=2'."""
        pairs = {}
        for part in s.split(","):
            if "=" not in part:
                continue
            key, value = part.split("=", 1)
            # Strip whitespace so "a = 1" and "a=1" give the same result,
            # matching how bracket entries are handled below.
            pairs[key.strip()] = value.strip()
        return pairs

    s = spec.strip()

    # Extract bracket content if present (only a suffix at the very end counts)
    m = re.search(r"\[(.*?)\]$", s)
    bracket_kvs = {}
    numeric_kvs = {}
    if m:
        bracket = m.group(1).strip()
        if bracket:
            for part in bracket.split(","):
                part = part.strip()
                if not part:
                    # Tolerate empty entries such as a trailing comma
                    continue
                if ":" not in part:
                    raise ValueError(
                        f"Invalid entry '{part}' in '{spec}' (expected key:value)."
                    )
                key, value = part.split(":", 1)
                key = key.strip()
                value = value.strip()

                # Integers (with optional underscores)
                if re.fullmatch(r"\d(?:_?\d)*", value):
                    numeric_kvs[key] = int(value.replace("_", ""))
                else:
                    bracket_kvs[key] = value

        # Remove the bracket suffix from the working string
        s = s[: m.start()].rstrip()

    # Determine name (if any) and parse outer kvs (if any)
    name = None
    if "=" in s:
        kv_dict = _parse_kv_string(s)
    else:
        kv_dict = {}
        if s:
            name = s  # could represent a dataset, resource, or identifier

    # Merge: bracket options and numeric keys last
    kv_dict.update(bracket_kvs)
    kv_dict.update(numeric_kvs)

    return name, kv_dict

In [None]:
import re
# A dataset_args string may hold several specs separated by "|" or "+";
# split it, trim each piece and drop the empty ones.
dataset_args = '/data/lxy/diffusion/data/alpaca-zh-gpt[train:2000,test:200]'
specs = list(filter(None, map(str.strip, re.split(r"[|+]", dataset_args))))
print(specs)

# Only the first spec is parsed here: the dataset path plus its split sizes.
dataset_name_or_path, kvs = parse_spec(specs[0])
print(dataset_name_or_path)
print(kvs)



['/data/lxy/diffusion/data/alpaca-zh-gpt[train:2000,test:200]']
/data/lxy/diffusion/data/alpaca-zh-gpt
{'train': 2000, 'test': 200}


In [9]:
from datasets import load_from_disk
# Load the preprocessed dataset from disk, then show its summary followed by
# the first example.
orig_data_path = '/data/lxy/diffusion/data/alpaca-zh-gpt'
orig_dataset = load_from_disk(orig_data_path)

print(orig_dataset, orig_dataset[0], sep="\n")

Dataset({
    features: ['input_ids', 'labels', 'prompt_len'],
    num_rows: 48818
})
{'input_ids': [126080, 27, 91, 7351, 20679, 2983, 95591, 3840, 27, 91, 486, 20679, 2983, 95591, 198, 198, 5583, 30345, 7009, 12767, 311, 27, 91, 68, 335, 2983, 91, 3583, 91, 7351, 20679, 2983, 95591, 598, 10450, 27, 91, 486, 20679, 2983, 95591, 198, 198, 50219, 5583, 30345, 7009, 12767, 629, 198, 198, 16, 13, 112204, 4426, 3004, 51234, 1262, 14364, 19586, 5153, 5353, 42347, 389, 33687, 1193, 25327, 9428, 8712, 61409, 4659, 31874, 18956, 5938, 2366, 30057, 6631, 25490, 311, 198, 198, 17, 13, 220, 28836, 12936, 51234, 17813, 59260, 16692, 88873, 34672, 77601, 538, 17512, 15771, 34739, 22426, 9126, 18654, 851, 6373, 10187, 17512, 538, 8280, 8618, 2589, 5583, 30345, 72685, 311, 198, 198, 18, 13, 220, 20248, 22605, 311, 20248, 69347, 4659, 55041, 33897, 16953, 7877, 995, 5780, 220, 22, 12, 23, 220, 55336, 20248, 311, 11262, 20248, 30057, 22894, 6513, 24343, 4426, 8027, 2366, 4043, 4214, 20229, 73536, 311, 

In [14]:
# Hard-coded split: rows [0, 2000) for training, rows [2000, 2200) for testing.
train_data = orig_dataset.select(range(0, 2000))
test_data = orig_dataset.select(range(2000, 2000 + 200))

# Show the first example of each split.
print(train_data[0], test_data[0], sep="\n")

{'input_ids': [126080, 27, 91, 7351, 20679, 2983, 95591, 3840, 27, 91, 486, 20679, 2983, 95591, 198, 198, 5583, 30345, 7009, 12767, 311, 27, 91, 68, 335, 2983, 91, 3583, 91, 7351, 20679, 2983, 95591, 598, 10450, 27, 91, 486, 20679, 2983, 95591, 198, 198, 50219, 5583, 30345, 7009, 12767, 629, 198, 198, 16, 13, 112204, 4426, 3004, 51234, 1262, 14364, 19586, 5153, 5353, 42347, 389, 33687, 1193, 25327, 9428, 8712, 61409, 4659, 31874, 18956, 5938, 2366, 30057, 6631, 25490, 311, 198, 198, 17, 13, 220, 28836, 12936, 51234, 17813, 59260, 16692, 88873, 34672, 77601, 538, 17512, 15771, 34739, 22426, 9126, 18654, 851, 6373, 10187, 17512, 538, 8280, 8618, 2589, 5583, 30345, 72685, 311, 198, 198, 18, 13, 220, 20248, 22605, 311, 20248, 69347, 4659, 55041, 33897, 16953, 7877, 995, 5780, 220, 22, 12, 23, 220, 55336, 20248, 311, 11262, 20248, 30057, 22894, 6513, 24343, 4426, 8027, 2366, 4043, 4214, 20229, 73536, 311, 27, 91, 68, 335, 2983, 91, 3583, 91, 7351, 20679, 2983, 95591, 598, 10450, 27, 91, 486

In [15]:
# Split sizes come from the parsed spec (kvs). Each size is looked up once and
# clamped to the rows that are actually available, so a missing or oversized
# "train"/"test" entry can no longer produce indices past the end of the
# dataset. A missing "train" takes the whole dataset; a missing "test" takes
# whatever remains after the train rows.
n_total = len(orig_dataset)
n_train = min(kvs.get("train", n_total), n_total)
n_test = min(kvs.get("test", n_total - n_train), n_total - n_train)

# Train rows are [0, n_train); test rows follow directly after them.
train_datasets = orig_dataset.select(range(n_train))
test_datasets = orig_dataset.select(range(n_train, n_train + n_test))

In [16]:
# Show the first example of the train split and of the test split, one per line.
print("\n".join(str(split[0]) for split in (train_datasets, test_datasets)))

{'input_ids': [126080, 27, 91, 7351, 20679, 2983, 95591, 3840, 27, 91, 486, 20679, 2983, 95591, 198, 198, 5583, 30345, 7009, 12767, 311, 27, 91, 68, 335, 2983, 91, 3583, 91, 7351, 20679, 2983, 95591, 598, 10450, 27, 91, 486, 20679, 2983, 95591, 198, 198, 50219, 5583, 30345, 7009, 12767, 629, 198, 198, 16, 13, 112204, 4426, 3004, 51234, 1262, 14364, 19586, 5153, 5353, 42347, 389, 33687, 1193, 25327, 9428, 8712, 61409, 4659, 31874, 18956, 5938, 2366, 30057, 6631, 25490, 311, 198, 198, 17, 13, 220, 28836, 12936, 51234, 17813, 59260, 16692, 88873, 34672, 77601, 538, 17512, 15771, 34739, 22426, 9126, 18654, 851, 6373, 10187, 17512, 538, 8280, 8618, 2589, 5583, 30345, 72685, 311, 198, 198, 18, 13, 220, 20248, 22605, 311, 20248, 69347, 4659, 55041, 33897, 16953, 7877, 995, 5780, 220, 22, 12, 23, 220, 55336, 20248, 311, 11262, 20248, 30057, 22894, 6513, 24343, 4426, 8027, 2366, 4043, 4214, 20229, 73536, 311, 27, 91, 68, 335, 2983, 91, 3583, 91, 7351, 20679, 2983, 95591, 598, 10450, 27, 91, 486

In [17]:
from datasets import DatasetDict, load_from_disk
# Bundle the two splits into one DatasetDict keyed by split name.
results = DatasetDict()
results["train"] = train_datasets
results["test"] = test_datasets
results

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels', 'prompt_len'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['input_ids', 'labels', 'prompt_len'],
        num_rows: 200
    })
})