# 模型训练

In [None]:
import warnings

# Suppress all Python warnings for the rest of the notebook so cell output
# stays readable. Installs the same catch-all "ignore" filter as
# warnings.filterwarnings('ignore').
warnings.simplefilter('ignore')

## 加载模型

In [None]:
import torch
from transformers import AutoModelForCausalLM

# Load the model to be trained from a local checkpoint directory, on CPU,
# with bfloat16 weights. The KV cache is switched off (use_cache=False)
# because it is only useful for generation, not for training.
pretrained_model = AutoModelForCausalLM.from_pretrained(
    "./models/upstage/TinySolar-308m-4k-init",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    use_cache=False,
)

In [None]:
# Display the model's module structure (the notebook renders its repr).
pretrained_model

## 加载数据

In [None]:
import datasets
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Map-style PyTorch dataset over a parquet file of token-id rows.

    Each row must contain an ``input_ids`` list. Every sample is returned
    as a dict of ``input_ids`` and ``labels`` tensors with identical
    contents, which is the usual setup for causal-LM pretraining.
    """

    def __init__(self, args, split="train"):
        """Load one split of the parquet data named by ``args.dataset_name``.

        Args:
            args: Configuration object exposing ``dataset_name``, which is
                passed to ``datasets.load_dataset`` as ``data_files``.
            split: Name of the split to load; defaults to ``"train"``.
        """
        self.args = args
        self.dataset = datasets.load_dataset(
            "parquet",
            data_files=args.dataset_name,
            split=split,
        )

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{"input_ids": ..., "labels": ...}``.

        Both values are int64 tensors built from the row's ``input_ids``.
        Labels are not shifted here; Hugging Face causal-LM models shift them
        internally when computing the loss.
        """
        # Index the dataset once; each `self.dataset[idx]` lookup materializes
        # the whole row, so the previous double lookup did that work twice.
        token_ids = self.dataset[idx]["input_ids"]
        input_ids = torch.tensor(token_ids, dtype=torch.long)
        # Independent copy, so an in-place edit to one tensor cannot leak
        # into the other.
        labels = input_ids.clone()
        return {"input_ids": input_ids, "labels": labels}

## 定义训练参数

In [None]:
from dataclasses import dataclass, field
import transformers


@dataclass
class CustomArguments(transformers.TrainingArguments):
    """Training arguments for this notebook's small pretraining run.

    Subclasses ``transformers.TrainingArguments``: adds dataset-related
    fields and overrides the defaults of several inherited fields. It is
    parsed with ``transformers.HfArgumentParser``, so any field can also be
    set from an argument list (e.g. ``["--max_steps", "100"]``).
    """

    # ---- Dataset configuration ----
    # Parquet file(s) to train on; CustomDataset passes this to
    # `datasets.load_dataset("parquet", data_files=...)` and reads an
    # `input_ids` column from each row.
    dataset_name: str = field(
        default="./parquet/packaged_pretrain_dataset.parquet")
    # Number of subprocesses for data preprocessing.
    # NOTE(review): not read by CustomDataset in this notebook.
    num_proc: int = field(default=1)
    # Maximum sequence length.
    # NOTE(review): not read by CustomDataset in this notebook; each sample's
    # length is whatever its parquet row contains.
    max_seq_length: int = field(default=32)

    # ---- Core training configuration ----
    # Random seed, for reproducible runs.
    seed: int = field(default=0)
    # Optimizer name; "adamw_torch" selects PyTorch's AdamW implementation.
    optim: str = field(default="adamw_torch")
    # Stop training after this many steps.
    max_steps: int = field(default=30)
    # Batch size per device during training.
    per_device_train_batch_size: int = field(default=2)

    # ---- Other training configuration ----
    # Initial learning rate for the optimizer.
    learning_rate: float = field(default=5e-5)
    # Weight decay; 0 disables it.
    # NOTE(review): int literal on a float field; 0.0 would match the annotation.
    weight_decay: float = field(default=0)
    # Number of steps for the learning-rate warmup phase.
    warmup_steps: int = field(default=10)
    # Learning-rate scheduler type.
    lr_scheduler_type: str = field(default="linear")
    # Gradient checkpointing: trade extra compute for lower activation memory.
    gradient_checkpointing: bool = field(default=True)
    # Number of subprocesses for data loading.
    dataloader_num_workers: int = field(default=2)
    # Train in bfloat16 precision.
    # NOTE(review): the model in this notebook is loaded with device_map="cpu";
    # confirm bf16 training is supported on the hardware actually used.
    bf16: bool = field(default=True)
    # Number of steps to accumulate gradients before each weight update
    # (1 = update after every step).
    gradient_accumulation_steps: int = field(default=1)

    # ---- Logging configuration ----
    # Log training metrics every this many steps.
    logging_steps: int = field(default=3)
    # Reporting integrations (e.g. WandB, TensorBoard); "none" uses none of them.
    report_to: str = field(default="none")

    # ---- Saving configuration (uncomment to save periodic checkpoints) ----
    # save_strategy: str = field(default="steps")          # Can be replaced with "epoch"
    # save_steps: int = field(default=3)                   # Frequency of saving training checkpoint
    # save_total_limit: int = field(default=2)             # The total number of checkpoints to be saved

In [None]:
# Build a CustomArguments instance from an explicit argument list (only the
# required --output_dir) rather than the process's command line.
parser = transformers.HfArgumentParser(CustomArguments)
(args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "output"],
)

In [None]:
# Load the training split of the parquet file configured in `args`.
train_dataset = CustomDataset(args, split="train")

In [None]:
# Sanity-check the tensor shape of the first training sample.
first_sample_shape = train_dataset[0]['input_ids'].shape
print(f"Input shape:  {first_sample_shape}")

## 训练并监测损失

In [None]:
from transformers import Trainer, TrainingArguments, TrainerCallback

# Custom callback that records every metrics dict the Trainer passes to
# `on_log` (the training loss among them), for inspection after training.


class LossLoggingCallback(TrainerCallback):
    """Collect each non-None ``logs`` payload from ``on_log`` into ``self.logs``."""

    def __init__(self):
        # Log dicts in the order the Trainer emitted them.
        self.logs = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append ``logs`` to ``self.logs``; ignore calls without a payload."""
        if logs is None:
            return
        self.logs.append(logs)


# Single instance, handed to the Trainer via `callbacks=[...]`.
loss_logging_callback = LossLoggingCallback()

In [None]:
from transformers import Trainer

# Train on train_dataset only. No evaluation set is passed; Trainer's
# eval_dataset already defaults to None.
trainer = Trainer(
    model=pretrained_model,
    args=args,
    train_dataset=train_dataset,
    callbacks=[loss_logging_callback],
)

# Runs for args.max_steps steps; each logged metrics dict is also captured
# in loss_logging_callback.logs.
trainer.train()

In [None]:
# Saving configuration: to save intermediate checkpoints during training,
# uncomment these fields in the CustomArguments class and re-run the cells.
# save_strategy: str = field(default="steps")          # Can be replaced with "epoch"
# save_steps: int = field(default=3)                   # Frequency of saving training checkpoint
# save_total_limit: int = field(default=2)             # The total number of checkpoints to be saved

## 检查中间检查点模型的性能

In [None]:
from transformers import AutoTokenizer, TextStreamer
# Tokenizer used to encode the prompt and decode generated text below.
# NOTE(review): loaded from TinySolar-248m-4k while the evaluated weights come
# from a different checkpoint; confirm the two share the same vocabulary.
model_name_or_path = "./models/upstage/TinySolar-248m-4k"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

In [None]:
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM
import torch

# Reload an intermediate training checkpoint, in bfloat16, letting
# device_map="auto" choose where to place the weights.
model_name_or_path = "./models/output/checkpoint-10000"
model2 = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

In [None]:
prompt = "I am an engineer. I love"

# Tokenize the prompt and move its tensors to the device model2 lives on.
inputs = tokenizer(prompt, return_tensors="pt").to(model2.device)

# Print generated text as it is produced, omitting the prompt and any
# special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Sample up to 64 new tokens at temperature 1.0, with the KV cache on for
# faster decoding.
outputs = model2.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    temperature=1.0,
    use_cache=True,
    streamer=streamer,
)