# SFT models

---

Models:
  - 'answerdotai/ModernBERT-base'
  - 'answerdotai/ModernBERT-large'
  - 'google/t5gemma-b-b-prefixlm-it'
  - 'google/t5gemma-s-s-prefixlm'
  - 'google/t5gemma-s-s-prefixlm-it'
  - 'google/t5gemma-s-s-ul2'
  - 'google/t5gemma-s-s-ul2-it'
  - 'google/t5gemma-b-b-ul2'
  - 'google/t5gemma-b-b-prefixlm'
  - 'google/t5gemma-b-b-ul2-it'
  - 'google/t5gemma-l-l-ul2'
  - 'google/t5gemma-l-l-prefixlm'
  - 'google/t5gemma-l-l-ul2-it'
  - 'google/t5gemma-l-l-prefixlm-it'
  - 'google/t5gemma-2-270m-270m'

## Import libraries

In [None]:
from trl import SFTTrainer, SFTConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from typing import Dict

## SFT config

In [None]:
# Logging and evaluation share a single step interval.
eval_steps = 120

sft_config = SFTConfig(
    # Output location; save_strategy='no' disables intermediate checkpoints.
    output_dir='../output_sft_models',
    overwrite_output_dir=True,
    save_strategy='no',

    # Schedule and optimizer.
    num_train_epochs=1,
    learning_rate=5e-5,
    warmup_ratio=0.1,
    # Author's note on alternatives: adam x4, sgd x2, adagrad x2, adafactor x2,
    # rmsprop x2 (presumably relative memory cost — confirm).
    optim='adalomo',

    # Sequence length, batching and activation memory.
    max_length=2048,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=2,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},

    # Evaluation and logging, both every `eval_steps` steps (and once before training).
    eval_strategy="steps",
    eval_steps=eval_steps,
    eval_on_start=True,
    logging_steps=eval_steps,
    bf16_full_eval=True,

    # Loss only on assistant turns. NOTE(review): per TRL docs this needs a
    # conversational dataset and a chat template that can emit an
    # assistant-token mask — verify for each model's template.
    assistant_only_loss=True
)

## SFT function

In [None]:
def sft_model(
    model_name : str = None,
    models_path : str = None,
    save_models_path : str = None,
    sft_config : SFTConfig = None,
    test_dataset : Dict = None,
    train_dataset : Dict = None
    ):
    """Fine-tune one locally stored model with TRL's SFTTrainer and save it.

    The model and tokenizer are loaded from
    ``<models_path>/<model_name with '/' replaced by '--'>``. The trained
    model is written to the same sub-directory name under
    ``save_models_path``.

    Args:
        model_name: Hub-style model id, e.g. ``'google/t5gemma-s-s-prefixlm'``.
            It is used only to build the directory name and the return message.
        models_path: Directory holding the locally stored models.
        save_models_path: Directory under which the trained model is saved.
        sft_config: Training arguments, passed to ``SFTTrainer`` as ``args``.
        test_dataset: Evaluation split, passed to ``SFTTrainer`` as
            ``eval_dataset``. It is needed whenever ``sft_config`` enables
            evaluation (e.g. ``eval_strategy="steps"`` or ``eval_on_start=True``).
            NOTE(review): despite the ``Dict`` annotation, this is presumably a
            ``datasets.Dataset``; confirm against the caller.
        train_dataset: Training split (see the note on ``test_dataset``).

    Returns:
        The message ``'<model_name> is trained!'``.

    Raises:
        ValueError: If ``model_name``, ``models_path``, ``save_models_path``
            or ``train_dataset`` is not provided.

    NOTE(review): checkpoints are loaded with ``AutoModelForCausalLM``.
    Encoder-only (ModernBERT) and encoder-decoder (T5Gemma) checkpoints may
    not have a causal-LM mapping, so confirm that each target model loads
    this way.
    """
    import gc

    # Fail early with a clear message instead of an AttributeError on None.
    required = {
        'model_name': model_name,
        'models_path': models_path,
        'save_models_path': save_models_path,
        'train_dataset': train_dataset,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError(
            f"sft_model() missing required argument(s): {', '.join(missing)}"
        )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_dir = model_name.replace('/', '--')
    input_path = os.path.join(models_path, model_dir)
    output_path = os.path.join(save_models_path, model_dir)

    model = AutoModelForCausalLM.from_pretrained(input_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(input_path)

    sft_trainer = SFTTrainer(
        model=model,
        args=sft_config,
        # SFTTrainer has no `test_dataset` keyword; the evaluation split
        # must be passed as `eval_dataset`.
        eval_dataset=test_dataset,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )
    try:
        sft_trainer.train()
        sft_trainer.save_model(output_path)
    finally:
        # Several models are trained in one process, so release this run's
        # references and cached GPU blocks before the next model is loaded.
        del sft_trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return f'{model_name} is trained!'

## Shut down the kernel

In [None]:
import os
import signal

# Send SIGTERM to this very process, which terminates the notebook kernel
# running the cells above.
own_pid = os.getpid()
os.kill(own_pid, signal.SIGTERM)