# Load the model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Hub identifier of the base model that will be fine-tuned.
checkpoint = "HuggingFaceTB/SmolLM-135M"

# device_map="auto" leaves device placement of the weights to the library.
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

# Reuse the end-of-sequence token as the padding token so batches can be padded.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token

# Sanity check: report which device the model ended up on.
print(f"Model is hosted on: {model.device}")

# Load the data

In [None]:
import pickle as pkl

In [None]:
# Load the pickled training and validation datasets.
# Context managers make sure both file handles are closed as soon as the
# objects have been deserialised (the bare open() calls leaked them).
# NOTE(review): pickle can execute arbitrary code while loading -- only
# unpickle dataset files that come from a trusted source.
with open("../Dataset/ChatTrain2.pkl", "rb") as train_file:
    train_data = pkl.load(train_file)
with open("../Dataset/ChatTest2.pkl", "rb") as val_file:
    val_data = pkl.load(val_file)

# Display both datasets as the cell output for a quick visual check.
train_data, val_data

# Set up the training infrastructure

In [None]:
from trl import SFTTrainer, SFTConfig

In [None]:
# Hyper-parameters for the supervised fine-tuning run.
sft_config = SFTConfig(
    # Where checkpoints are written.
    output_dir="./ChatTraining_3_Checkpoints",
    # Optimisation schedule: 200 optimizer steps in total; each step
    # accumulates 16 micro-batches of 8 samples (8 * 16 = 128 per device).
    max_steps=200,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=16,
    learning_rate=5e-5,
    # Evaluate and log every 50 steps.
    eval_strategy="steps",
    eval_steps=50,
    logging_steps=50,
    # Checkpoint every 60 steps, keeping at most 5 checkpoints on disk.
    save_steps=60,
    save_total_limit=5,
    # Do not report metrics to any external tracking integration.
    report_to="none",
)

In [None]:
# Wire the model, both datasets and the training configuration into the trainer.
# NOTE(review): the `tokenizer` configured earlier (pad_token = eos_token) is
# not passed here, so the trainer presumably loads its own tokenizer from the
# model checkpoint -- confirm this is intended for the installed TRL version.
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_data,
    eval_dataset=val_data,
)

# Train the model

In [None]:
# Run the fine-tuning loop configured above; the returned training summary is
# shown as the cell output.
trainer.train()
# Alternative for an interrupted run -- presumably resumes from the most recent
# checkpoint in output_dir (verify against the Trainer documentation):
# trainer.train(resume_from_checkpoint=True)

In [None]:
import torch
# Ask PyTorch to release its cached, currently unused CUDA memory after training.
torch.cuda.empty_cache()

# Save the model

In [None]:
# Write the fine-tuned model to a local directory.
# NOTE(review): "vxx" looks like a placeholder version suffix -- set a real
# version before saving so earlier exports are not overwritten.
trainer.save_model("SmolLM-Our-Instruct-vxx")