## Unsloth on Amazon SageMaker AI

In this example, we kick off a training job using the ModelTrainer API from the SageMaker Python SDK. The training script is based on the example script from Unsloth. One thing to note: we have to patch `flash_attn` to work with torch 2.7.0 and CUDA 12.8. The code for this patch is wrapped in the first function of the `unsloth_train.py` file in the `scripts` directory.


In [None]:
from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.configs import SourceCode, Compute, InputData

# image URI for the training job
pytorch_image = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.6.0-gpu-py312-cu126-ubuntu22.04-sagemaker"
# you can find all available images here
# https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths.html

# define the script to be run
source_code = SourceCode(
    source_dir="scripts",
    requirements="requirements.txt",
    entry_script="unsloth_train.py",
)

# Compute configuration for the training job
compute = Compute(
    instance_count=1,
    instance_type="ml.g6e.2xlarge",
    # instance_type="ml.p4d.24xlarge",
    volume_size_in_gb=96,
    keep_alive_period_in_seconds=3600,
)

# define the ModelTrainer
model_trainer = ModelTrainer(
    training_image=pytorch_image,
    source_code=source_code,
    base_job_name="unsloth-gemma3-4b-it",
    compute=compute,
)

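# optionally pass input data (uncomment and set training_input_path to your S3 URI)
# input_data = InputData(
#     channel_name="train",
#     data_source=training_input_path,  # S3 path where training data is stored
# )
# model_trainer.train(input_data_config=[input_data], wait=False)

# start the training job without blocking the notebook
model_trainer.train(wait=False)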
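
Because `train(wait=False)` returns as soon as the job is submitted, the notebook does not block until training finishes. Below is a minimal sketch of polling for completion. It assumes a client that behaves like `boto3.client("sagemaker")`, whose `describe_training_job` response includes a `TrainingJobStatus` field, and a job name you look up yourself (for example in the SageMaker console). The helper name `wait_for_training_job` and its parameters are hypothetical and not part of the SageMaker SDK.

```python
import time

# Statuses after which a SageMaker training job no longer changes.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(client, job_name, poll_seconds=60, max_polls=240, sleep=time.sleep):
    """Poll describe_training_job until the job reaches a terminal status.

    `client` is expected to behave like boto3.client("sagemaker").
    This helper is a hypothetical illustration, not an SDK function.
    """
    for _ in range(max_polls):
        response = client.describe_training_job(TrainingJobName=job_name)
        status = response["TrainingJobStatus"]
        if status in TERMINAL_STATUSES:
            return status
        sleep(poll_seconds)
    raise TimeoutError(f"{job_name} did not finish after {max_polls} polls")


# Usage (requires AWS credentials; the job name is a placeholder):
# import boto3
# status = wait_for_training_job(boto3.client("sagemaker"), "<your-training-job-name>")
```

Passing the client in as an argument keeps the helper independent of AWS credentials, so it can be exercised with a stub before running against a real job.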

---

### Appendix

You can add Weights & Biases (`wandb`) to the trainer to track metrics and visualize the training process. To do this:

1. Add `wandb` and `python-dotenv` to the `requirements.txt` file in the `scripts` directory.
2. Add a `.env` file in the root of your project directory containing your wandb API key (`WANDB_API_KEY`).
3. Pass the wandb key to the training job as an environment variable.
4. Finally, in the `unsloth_train.py` file, set the `report_to` argument to `"wandb"` in the `SFTConfig`.
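
The last step can be sketched as follows. The helper `resolve_report_to` is hypothetical and is not part of the shipped `unsloth_train.py`. It assumes the key reaches the training container as the `WANDB_API_KEY` environment variable, and it falls back to `"none"`, the value that disables logging integrations in Hugging Face training arguments, when the key is absent.

```python
import os


def resolve_report_to(environ=None):
    """Pick the value for the report_to argument of SFTConfig.

    Returns "wandb" when WANDB_API_KEY is set to a non-empty value,
    otherwise "none" so that training still runs without W&B.
    """
    environ = os.environ if environ is None else environ
    return "wandb" if environ.get("WANDB_API_KEY") else "none"


# Inside unsloth_train.py this might be used as (other arguments omitted):
# training_args = SFTConfig(report_to=resolve_report_to())
```

Deriving the value from the environment means the same script can run with or without a wandb key, without editing the config between runs.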


In [None]:
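import os

from dotenv import load_dotenv

# load WANDB_API_KEY from the local .env file into the process environment
load_dotenv()

# environment variables forwarded to the training container
env = {"WANDB_API_KEY": os.getenv("WANDB_API_KEY")}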

model_trainer = ModelTrainer(
    training_image=pytorch_image,
    source_code=source_code,
    base_job_name="unsloth-gemma3-4b-it",
    compute=compute,
    environment=env,
)
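
Note that `os.getenv` returns `None` when the variable is not set, for example if the `.env` file is missing. A small guard like the one below can fail early in the notebook instead of submitting a job without the key. The helper name `build_training_env` is hypothetical and uses only the standard library.

```python
import os


def build_training_env(required=("WANDB_API_KEY",), environ=None):
    """Collect required variables into a dict of strings, or raise if any is missing."""
    environ = os.environ if environ is None else environ
    missing = [name for name in required if not environ.get(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return {name: environ[name] for name in required}


# Usage: call after load_dotenv(), then pass the result as environment=env
# env = build_training_env()
```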

---
