In [1]:
# cd ..

In [None]:
import os

os.environ["JUPYTER"] = "True"
from hftrainer.trainer.base import BaseTrainer
from speedy import *
from datasets import load_dataset
from transformers import PreTrainedTokenizer
from datasets import Dataset


class CustomTrainer(BaseTrainer):
    def load_datasets(self):
        logger.debug("Loading datasets from provided path.")

        def to_msgs(item):
            return [
                {"role": "user", "content": item["instruction"] + "\n" + item["input"]},
                {"role": "assistant", "content": item["output"]},
            ]

        def preprocess(item, tokenizer: "PreTrainedTokenizer"):
            try:
                text = tokenizer.apply_chat_template(
                    to_msgs(item), return_tensors="pt", tokenize=False
                )
                ids = tokenizer(
                    text,
                    padding="max_length",
                    truncation=True,
                    max_length=self.training_args.model_max_length,
                    return_tensors="pt",
                )
                return {
                    "input_ids": ids.input_ids.flatten(),
                    "attention_mask": ids.attention_mask.flatten(),
                    "labels": ids.input_ids.flatten(),
                }
            except Exception as e:
                logger.error(f"Error in preprocessing: {e}")
                raise

        path = self.data_args.data_path
        logger.debug(f"Loading data from path: {path}")
        data = load_by_ext(path)[:1000]
        ds = Dataset.from_list(data)

        def map_fn(item):
            return preprocess(item, self.tokenizer)

        logger.debug("Mapping dataset with preprocessing function.")
        ds = ds.map(map_fn)
        dataset = ds.train_test_split(test_size=0.1)
        logger.debug("Dataset loaded and split into train and test sets.")
        return dataset


logger.debug("Starting CustomTrainer with provided configuration.")
trainer = CustomTrainer("../config/template_args_macos.yaml", verbose=False)
trainer.train()
logger.debug("Training completed successfully.")

In [5]:
val_loader = trainer.get_eval_dataloader()

In [25]:
from transformers import TextStreamer
model = trainer.model.eval()
tokenizer = trainer.tokenizer
with torch.no_grad():
    data = [{"role": "user", "content": "fuck you",}]
    text = tokenizer.apply_chat_template(data, tokenize=False, add_generation_prompt=True)
    ids = tokenizer(text, return_tensors='pt').to(model.device)
    gen = model.generate(**ids, streamer=TextStreamer(tokenizer))

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
fuck you<|im_end|>
<|im_start|>assistant
Fucking you.<|im_end|>
<|endoftext|>


{'input_ids': tensor([[151644,   8948,    198,   2610,    525,    264,  10950,  17847, 151645,
            198, 151644,    872,    198,   3838,    374,    279,   6722,    315,
           9625,     30, 151645,    198, 151644,  77091,    198]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]])}

tensor([[151644,   8948,    198,   2610,    525,    264,  10950,  17847, 151645,
            198, 151644,    872,    198,  23227,    551,    279,   2701,   4244,
           1119,    264,  22414,  11652,    624,     83,    541,   2272,    279,
            304, 151645,    198, 151644,  77091,    198,  25749,    304,    279,
          16217,    374,   2480,    315,  45440,     13, 151645,    198, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151