In [1]:
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U datasets
!pip install -q -U trl

Import all the necessary packages.

In [1]:
import torch, multiprocessing
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments
)
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig

  from .autonotebook import tqdm as notebook_tqdm


#SmolLM 135M
# Distilled Supervised Fine-tuning

First, activate the use of bfloat16 and FlashAttenion if they are compatible with your GPU.
Then, load the tokenizer and configure padding

In [3]:
major_version, minor_version = torch.cuda.get_device_capability()
if major_version >= 8:
  !pip install flash-attn
  attn_implementation='flash_attention_2'
  print("Your GPU is compatible with FlashAttention.")
else:
  attn_implementation='eager'
  print("Your GPU is not compatible with FlashAttention.")

model_name = "HuggingFaceTB/SmolLM-135M"
#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = "<|im_end|>"
tokenizer.pad_token_id = 2
tokenizer.padding_side = 'left' #Necessary for FlashAttention compatibility




Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Your GPU is compatible with FlashAttention.


In [4]:
chat_template = """
Human: {human_input}
Assistant: {assistant_output}
"""
tokenizer.chat_template = chat_template

Load the version of ultrachat prepared by Hugging Face. I only load 5% of the test split to speed up validation.

In [5]:
dataset_train_sft = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
dataset_test_sft = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft[:5%]")

Load the model that we will train with SFT and activate gradient checkpointing to save memory.

In [6]:
model = AutoModelForCausalLM.from_pretrained(
          model_name, attn_implementation=attn_implementation, device_map={"": 0}
)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour


For this demonstration, I trained for only 4000 steps. One epoch would be ideal.

In [7]:
training_arguments = SFTConfig(
        output_dir="./sft_smollm_135M/",
        eval_strategy="steps",
        do_eval=True,
        optim="adamw_torch",
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        per_device_eval_batch_size=8,
        log_level="debug",
        save_steps=500,
        logging_steps=50,
        learning_rate=2e-5,
        fp16= not torch.cuda.is_bf16_supported(),
        bf16= torch.cuda.is_bf16_supported(),
        eval_steps=50,
        max_steps=4000,
        warmup_steps=30,
        max_seq_length=2048,
        lr_scheduler_type="linear",
)

Start training:

In [8]:
trainer = SFTTrainer(
        model=model,
        train_dataset=dataset_train_sft,
        eval_dataset=dataset_test_sft,
        tokenizer=tokenizer,
        args=training_arguments,
)

trainer.train()

Map: 100%|██████████| 207865/207865 [00:31<00:00, 6661.73 examples/s]
Map: 100%|██████████| 1156/1156 [00:00<00:00, 6559.58 examples/s]
max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend
Currently training with a batch size of: 8
***** Running training *****
  Num examples = 207,865
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 2
  Total optimization steps = 4,000
  Number of trainable parameters = 134,515,008
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Detected flash_attn version: 2.6.3


Step,Training Loss,Validation Loss
50,1.5396,0.000199
100,0.0001,2.7e-05
150,0.0,1.6e-05
200,0.0,1.2e-05
250,0.0,9e-06
300,0.0,8e-06
350,0.0,7e-06
400,0.0,5e-06
450,0.0,4e-06
500,0.0,3e-06



***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8
Saving model checkpoint to ./sft_smollm_135M/checkpoint-500
Configuration saved in ./sft_smollm_135M/checkpoint-500/config.json
Configuration saved in ./sft_smollm_135M/checkpoint-500/generation_config.json
Model weights saved in ./sft_smollm_135M/checkpoint-500/model.safetensors
tokenizer

TrainOutput(global_step=4000, training_loss=0.01924704824045125, metrics={'train_runtime': 3833.8411, 'train_samples_per_second': 16.693, 'train_steps_per_second': 1.043, 'total_flos': 734078287872000.0, 'train_loss': 0.01924704824045125, 'epoch': 0.3078817733990148})

# Distilled DPO

Load the model that will be trained with DPO.

In [2]:
major_version, minor_version = torch.cuda.get_device_capability()
if major_version >= 8:
  !pip install flash-attn
  attn_implementation='flash_attention_2'
  print("Your GPU is compatible with FlashAttention.")
else:

  attn_implementation='eager'
  print("Your GPU is not compatible with FlashAttention.")

model_name = "HuggingFaceTB/SmolLM-135M"
#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = "<|im_end|>"
tokenizer.pad_token_id = 2
tokenizer.padding_side = 'left' #Necessary for FlashAttention compatibility

model = AutoModelForCausalLM.from_pretrained(
          model_name, attn_implementation=attn_implementation, device_map={"": 0}
)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Your GPU is compatible with FlashAttention.


You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour


We will use as a reference model our checkpoint trained with SFT.

In [3]:
ref_model = AutoModelForCausalLM.from_pretrained(
          "./sft_smollm_135M/checkpoint-4000", attn_implementation=attn_implementation, device_map={"": 0}
)

In [4]:
ref_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNo

Format UltraFeedback with a default chat template for DPO training.

In [5]:
dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split=["train_prefs","test_prefs"])

tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

def process(row):
    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
    return row

dataset[0] = dataset[0].map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)

dataset[1] = dataset[1].map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)

print(dataset)

Generating train_prefs split: 100%|██████████| 61135/61135 [00:00<00:00, 64159.41 examples/s]
Generating train_sft split: 100%|██████████| 61135/61135 [00:00<00:00, 63869.72 examples/s]
Generating test_prefs split: 100%|██████████| 2000/2000 [00:00<00:00, 60216.70 examples/s]
Generating test_sft split: 100%|██████████| 1000/1000 [00:00<00:00, 51992.71 examples/s]
Generating train_gen split: 100%|██████████| 61135/61135 [00:00<00:00, 72817.88 examples/s]
Generating test_gen split: 100%|██████████| 1000/1000 [00:00<00:00, 58473.50 examples/s]
Map (num_proc=16): 100%|██████████| 61135/61135 [00:02<00:00, 24502.51 examples/s]
Map (num_proc=16): 100%|██████████| 2000/2000 [00:00<00:00, 5628.58 examples/s]


[Dataset({
    features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected'],
    num_rows: 61135
}), Dataset({
    features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected'],
    num_rows: 2000
})]


For this demonstration, I trained for only 4000 steps. DPO learns very slowly so one epoch would be ideal. I didn't search for a better learning rate. A higher learning rate may yield better results.

In [6]:
training_arguments = DPOConfig(
        output_dir="./dpo_smollm_135M/",
        evaluation_strategy="steps",
        do_eval=True,
        optim="adamw_torch",
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        per_device_eval_batch_size=8,
        log_level="debug",
        save_steps=500,
        fp16= not torch.cuda.is_bf16_supported(),
        bf16= torch.cuda.is_bf16_supported(),
        logging_steps=50,
        learning_rate=1e-7,
        eval_steps=50,
        max_steps=4000,
        warmup_steps=30,
        lr_scheduler_type="linear",
        beta=0.1,
)



Start DPO training

In [7]:
trainer = DPOTrainer(
    model,
    ref_model=ref_model,
    args=training_arguments,
    train_dataset=dataset[0],
    eval_dataset=dataset[1],
    tokenizer=tokenizer
)

trainer.train()

Extracting prompt from train dataset: 100%|██████████| 61135/61135 [00:10<00:00, 5645.73 examples/s]
Applying chat template to train dataset: 100%|██████████| 61135/61135 [00:10<00:00, 5709.83 examples/s]
Extracting prompt from eval dataset: 100%|██████████| 2000/2000 [00:00<00:00, 5896.59 examples/s]
Applying chat template to eval dataset: 100%|██████████| 2000/2000 [00:00<00:00, 6133.08 examples/s]
Tokenizing train dataset: 100%|██████████| 61135/61135 [07:59<00:00, 127.54 examples/s]
Tokenizing eval dataset: 100%|██████████| 2000/2000 [00:15<00:00, 129.95 examples/s]
max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend
Currently training with a batch size of: 8
The following columns in the training set don't have a corresponding argument in `LlamaForCausalLM.forward` and have been ignored: chosen, score_rejected, rejected, prompt_id, messages, prompt, score_chosen. If chosen, score_rejected, rejected, prompt_id, messages, prompt,

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Detected flash_attn version: 2.6.3


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.71 GiB. GPU 

#SmolLM 360M
The same but for the 360M version.

In [None]:
import torch, multiprocessing
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments
)
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig

In [None]:
major_version, minor_version = torch.cuda.get_device_capability()
if major_version >= 8:
  !pip install flash-attn
  attn_implementation='flash_attention_2'
  print("Your GPU is compatible with FlashAttention.")
else:
  attn_implementation='eager'
  print("Your GPU is not compatible with FlashAttention.")

model_name = "HuggingFaceTB/SmolLM-360M"
#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = "<|im_end|>"
tokenizer.pad_token_id = 2
tokenizer.padding_side = 'left' #Necessary for FlashAttention compatibility

dataset_train_sft = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
dataset_test_sft = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft[:5%]")

model = AutoModelForCausalLM.from_pretrained(
          model_name, attn_implementation=attn_implementation, device_map={"": 0}
)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})

training_arguments = SFTConfig(
        output_dir="./sft_smollm_360M/",
        eval_strategy="steps",
        do_eval=True,
        optim="adamw_torch",
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        per_device_eval_batch_size=8,
        log_level="debug",
        save_steps=500,
        logging_steps=50,
        learning_rate=2e-5,
        fp16= not torch.cuda.is_bf16_supported(),
        bf16= torch.cuda.is_bf16_supported(),
        eval_steps=50,
        max_steps=4000,
        warmup_steps=30,
        max_seq_length=2048,
        lr_scheduler_type="linear",
)

trainer = SFTTrainer(
        model=model,
        train_dataset=dataset_train_sft,
        eval_dataset=dataset_test_sft,
        tokenizer=tokenizer,
        args=training_arguments,
)

trainer.train()

Your GPU is compatible with FlashAttention.


tokenizer_config.json:   0%|          | 0.00/3.69k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/801k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/466k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.10M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/831 [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/4.44k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/81.2M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/243M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/243M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/80.4M [00:00<?, ?B/s]

Generating train_sft split:   0%|          | 0/207865 [00:00<?, ? examples/s]

Generating test_sft split:   0%|          | 0/23110 [00:00<?, ? examples/s]

Generating train_gen split:   0%|          | 0/256032 [00:00<?, ? examples/s]

Generating test_gen split:   0%|          | 0/28304 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.45G [00:00<?, ?B/s]

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attentio

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]



Map:   0%|          | 0/207865 [00:00<?, ? examples/s]

No chat template is set for this tokenizer, falling back to a default class-level template. This is very error-prone, because models are often trained with templates different from the class default! Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which point any code depending on them will stop working. We recommend setting a valid chat template before then to ensure that this model continues working without issues.


Map:   0%|          | 0/1156 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend
Currently training with a batch size of: 8
***** Running training *****
  Num examples = 207,865
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 2
  Total optimization steps = 4,000
  Number of trainable parameters = 361,821,120
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Detected flash_attn version: 2.6.3


Step,Training Loss,Validation Loss
50,1.8196,1.706369
100,1.6092,1.579805
150,1.5273,1.549133
200,1.5498,1.533507
250,1.5261,1.52407
300,1.5073,1.517131
350,1.4875,1.511837
400,1.4945,1.507245
450,1.5003,1.503666
500,1.4981,1.500105



***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8
Saving model checkpoint to ./drive/M

Step,Training Loss,Validation Loss
50,1.8196,1.706369
100,1.6092,1.579805
150,1.5273,1.549133
200,1.5498,1.533507
250,1.5261,1.52407
300,1.5073,1.517131
350,1.4875,1.511837
400,1.4945,1.507245
450,1.5003,1.503666
500,1.4981,1.500105



***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8

***** Running Evaluation *****
  Num examples = 1156
  Batch size = 8
Saving model checkpoint to ./drive/MyDrive/sft_smollm_360M/checkpoint-4000
Configuration saved in ./drive/MyDrive/sft_smollm_360M/checkpoint-4000/config.json
Configuration saved in ./drive/MyDrive/sft_smollm_360M/checkpoint-4000/generation_config.json
Model weights saved in ./drive/MyDrive

TrainOutput(global_step=4000, training_loss=1.4588746814727782, metrics={'train_runtime': 29672.85, 'train_samples_per_second': 2.157, 'train_steps_per_second': 0.135, 'total_flos': 2.279362120992e+17, 'train_loss': 1.4588746814727782, 'epoch': 0.3078817733990148})

In [None]:
major_version, minor_version = torch.cuda.get_device_capability()
if major_version >= 8:
  !pip install flash-attn
  attn_implementation='flash_attention_2'
  print("Your GPU is compatible with FlashAttention.")
else:

  attn_implementation='eager'
  print("Your GPU is not compatible with FlashAttention.")

model_name = "HuggingFaceTB/SmolLM-360M"
#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = "<|im_end|>"
tokenizer.pad_token_id = 2
tokenizer.padding_side = 'left' #Necessary for FlashAttention compatibility

model = AutoModelForCausalLM.from_pretrained(
          model_name, attn_implementation=attn_implementation, device_map={"": 0}
)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})

ref_model = AutoModelForCausalLM.from_pretrained(
          "./sft_smollm_360M/checkpoint-4000", attn_implementation=attn_implementation, device_map={"": 0}
)

dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split=["train_prefs","test_prefs"])

tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"


def process(row):
    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
    return row

dataset[0] = dataset[0].map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)

dataset[1] = dataset[1].map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)

print(dataset)

training_arguments = DPOConfig(
        output_dir="./dpo_smollm_360M/",
        evaluation_strategy="steps",
        do_eval=True,
        optim="adamw_torch",
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        per_device_eval_batch_size=8,
        log_level="debug",
        save_steps=500,
        fp16= not torch.cuda.is_bf16_supported(),
        bf16= torch.cuda.is_bf16_supported(),
        logging_steps=50,
        learning_rate=1e-7,
        eval_steps=50,
        max_steps=4000,
        warmup_steps=30,
        lr_scheduler_type="linear",
        beta=0.1,
)

trainer = DPOTrainer(
    model,
    ref_model=ref_model,
    args=training_arguments,
    train_dataset=dataset[0],
    eval_dataset=dataset[1],
    tokenizer=tokenizer
)

trainer.train()