In [1]:
import subprocess
import os
from trl import SFTConfig, SFTTrainer, ModelConfig
from datasets import load_dataset
from transformers import AutoTokenizer,AutoModelForCausalLM

# Source the host's network-acceleration script in a bash subshell and capture
# every environment line that mentions "proxy", so the same settings can be
# mirrored into this kernel's environment before anything is downloaded.
result = subprocess.run(
    'bash -c "source /etc/network_turbo && env | grep proxy"',
    shell=True,
    capture_output=True,
    text=True,
)
output = result.stdout

# Each captured line has the form NAME=VALUE; split on the first '=' only so
# values that themselves contain '=' are kept intact. Lines without '=' are skipped.
for env_line in output.splitlines():
    name, separator, setting = env_line.partition('=')
    if separator:
        os.environ[name] = setting


# Training split of the IMDB dataset, loaded from the Hugging Face hub.
dataset_train = load_dataset("stanfordnlp/imdb", split="train")

def prompt_completion_preprocess(example, num_prompt_words=5):
    """Split a raw text example into a prompt/completion pair.

    The text is split on whitespace; the first ``num_prompt_words`` words form
    the prompt and the remaining words form the completion.

    Args:
        example: Mapping with a ``'text'`` string entry.
        num_prompt_words: Number of leading words that go into the prompt
            (expected to be non-negative). Defaults to 5.

    Returns:
        dict with ``'prompt'`` and ``'completion'`` strings. When both parts
        are non-empty the completion starts with a single space, so that
        ``prompt + completion`` reproduces the space-joined word sequence
        instead of fusing the last prompt word with the first completion word.
    """
    words = example['text'].split()
    prompt = ' '.join(words[:num_prompt_words])
    completion = ' '.join(words[num_prompt_words:])
    # Prompt and completion are concatenated directly downstream, so the word
    # separator has to live at the start of the completion.
    if prompt and completion:
        completion = ' ' + completion
    return {'prompt': prompt, 'completion': completion}



# Drop the sentiment label so only the raw 'text' column remains.
# NOTE(review): prompt_completion_preprocess is defined above but is not applied
# (no .map call in this cell), so the dataset keeps its single 'text' column.
dataset_train = dataset_train.remove_columns("label")

In [2]:
BASE_MODEL = 'Qwen/Qwen2-0.5B'

# Model-loading options for the base checkpoint.
model_args = ModelConfig(
    model_name_or_path=BASE_MODEL,
    trust_remote_code=True,
    torch_dtype='auto',
)

# Tokenizer for the base checkpoint; fall back to the EOS token for padding
# when the tokenizer does not define a pad token of its own.
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    use_fast=True,
    trust_remote_code=model_args.trust_remote_code,
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Training hyper-parameters, checkpointing schedule and hub upload target.
training_args = SFTConfig(
    output_dir='/root/autodl-tmp/SFT_imdb',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    logging_steps=100,
    report_to='none',
    save_strategy="steps",
    save_steps=3000,
    save_total_limit=3,
    push_to_hub=True,
    hub_model_id='august66/qwen2-sft-final',
)

# Disable the KV cache only when gradient checkpointing is on; otherwise pass
# None and leave the decision to the model's own default.
if training_args.gradient_checkpointing:
    use_cache_setting = False
else:
    use_cache_setting = None

# Keyword arguments used to instantiate the model; attached to the training
# config because the trainer is later given the model as a name string.
model_kwargs = {
    'revision': model_args.model_revision,
    'trust_remote_code': model_args.trust_remote_code,
    'torch_dtype': model_args.torch_dtype,
    'use_cache': use_cache_setting,
    'device_map': 'auto',
}
training_args.model_init_kwargs = model_kwargs





In [3]:
# Build the SFT trainer. The model is passed as a name string, so it is
# instantiated by the trainer using training_args.model_init_kwargs.
trainer = SFTTrainer(
    model=model_args.model_name_or_path,
    args=training_args,
    train_dataset=dataset_train,
    processing_class=tokenizer,
)
train_output = trainer.train()

# Manually save the final model: periodic checkpoints are only written every
# `save_steps` steps, so the last optimizer steps may not be covered by one.
# NOTE(review): with push_to_hub=True, save_model is expected to upload the
# saved files to the hub as well — confirm for the installed transformers version.
trainer.save_model(training_args.output_dir)

# Keep the training summary as the cell's displayed result.
train_output

Step,Training Loss
100,3.2387
200,3.1856
300,3.2366
400,3.2084
500,3.1804
600,3.2048
700,3.1802
800,3.1738
900,3.1667
1000,3.1532


TrainOutput(global_step=9375, training_loss=3.0733343310546877, metrics={'train_runtime': 2972.0362, 'train_samples_per_second': 25.235, 'train_steps_per_second': 3.154, 'total_flos': 1.0690613019923866e+17, 'train_loss': 3.0733343310546877})

In [7]:
# Display the training split's summary (column names and row count).
dataset_train

Dataset({
    features: ['text'],
    num_rows: 25000
})