# Fine-Tuning HuggingFace Models Using Amazon SageMaker
In this example we take a sample [BERT model](https://huggingface.co/google-bert/bert-base-cased) and fine-tune it using SageMaker Training Jobs. With SageMaker Training Jobs, the containers and infrastructure layer are managed for you, and the resulting model artifact fits neatly into deploying to an endpoint.

## Additional Resources/Credits
- Learned a lot from PhilSchmid's example: https://github.com/huggingface/notebooks/tree/main/sagemaker/01_getting_started_pytorch
- Trainer Documentation: https://huggingface.co/docs/transformers/en/main_classes/trainer
- <b>NOTE</b> -> SageMaker checkpoints: https://docs.aws.amazon.com/sagemaker/latest/dg/model-checkpoints-enable.html. In this example we save checkpoints as part of the model data, which I wouldn't recommend because it adds to the model tarball size. Refer to these docs to decouple the checkpoints into a separate S3 location.
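
As a rough sketch of that decoupling: the SageMaker estimator accepts `checkpoint_s3_uri` and `checkpoint_local_path` arguments, and SageMaker syncs whatever the script writes to the local path up to the S3 location. The helper name `checkpoint_kwargs` and the `checkpoints` sub-prefix are assumptions for illustration, and the training script would also need to write its checkpoints under that local path instead of the model directory.

```python
def checkpoint_kwargs(bucket: str, prefix: str) -> dict:
    """Build estimator keyword arguments that route checkpoints to their own S3 location.

    Hypothetical helper: the "checkpoints" sub-prefix is an illustrative choice.
    """
    return {
        # S3 location that SageMaker syncs checkpoints to during training
        "checkpoint_s3_uri": f"s3://{bucket}/{prefix}/checkpoints",
        # Directory inside the training container that is synced to S3
        "checkpoint_local_path": "/opt/ml/checkpoints",
    }


kwargs = checkpoint_kwargs("my-bucket", "bert-intro-imdb")
print(kwargs["checkpoint_s3_uri"])
```

These keyword arguments could then be unpacked into the `HuggingFace(...)` estimator call alongside the other settings.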


## Setup
This notebook was run on a conda_python3 kernel. For the most part any base instance works for the notebook itself, since the infrastructure needed for training is supplied by the Training Job.

In [None]:
!pip install sagemaker datasets evaluate "transformers[torch]" "accelerate>=0.26.0" --quiet

## Config Setup

In [None]:
import os
import boto3
import pandas as pd
import sagemaker
from sagemaker.huggingface import HuggingFace
from sagemaker import get_execution_role
from sklearn.model_selection import train_test_split
session = sagemaker.Session()
role = get_execution_role()  # works in SageMaker Notebook/Studio

bucket = session.default_bucket()      # or set your own bucket name
prefix = "bert-intro-imdb"             # S3 prefix for this job

os.makedirs("data", exist_ok=True)
local_train_csv = "data/train.csv"
local_test_csv = "data/test.csv"

## Dataset Setup
We push the dataset to S3 as CSV files. Tokenization happens inside the training script, but it can also be done on the client side, depending on what you prefer.

In [None]:
from datasets import load_dataset
dataset = load_dataset("imdb")  # already has train / test splits

train_df = dataset["train"].to_pandas()
test_df = dataset["test"].to_pandas()

# keep columns "text" and "label" as your script expects
train_df[["text", "label"]].to_csv(local_train_csv, index=False)
test_df[["text", "label"]].to_csv(local_test_csv, index=False)

print("Wrote train.csv and test.csv from HF IMDb dataset")
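
Since the reviews contain commas, quotes, and line breaks, it can be worth confirming that the CSVs read back cleanly before uploading them. A minimal sketch using only the standard library follows; the `check_csv` helper is hypothetical and not part of the original notebook.

```python
import csv
import os


def check_csv(path: str, expected_columns=("text", "label")) -> int:
    """Verify the header matches expected_columns and return the number of data rows."""
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        header = next(reader)
        if tuple(header) != tuple(expected_columns):
            raise ValueError(f"{path}: unexpected header {header}")
        # csv.reader handles quoted fields, so a multi-line review counts as one row
        return sum(1 for _ in reader)


# Only runs the check when the files from the previous cell exist
for path in ("data/train.csv", "data/test.csv"):
    if os.path.exists(path):
        print(path, check_csv(path), "rows")
```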

In [None]:
s3 = boto3.client("s3")

s3.upload_file(local_train_csv, bucket, f"{prefix}/train/train.csv")
s3.upload_file(local_test_csv,  bucket, f"{prefix}/test/test.csv")

s3_train = f"s3://{bucket}/{prefix}/train"
s3_test  = f"s3://{bucket}/{prefix}/test"

print("Uploaded to:")
print("  ", s3_train)
print("  ", s3_test)

## Define HuggingFace Estimator
With this estimator we automatically pull the matching training container by specifying the transformers and torch versions. You can also pass hyperparameters to the training script through this map. We keep it minimal here; more advanced setups, such as distributed training, need additional configuration.

In [None]:
hyperparameters = {
    "model_name": "bert-base-cased",
    "num_train_epochs": 1,
}

estimator = HuggingFace(
    entry_point="train.py",           # your script
    source_dir="./scripts",           # folder containing train.py
    instance_type="ml.g5.12xlarge",   # 4-GPU instance; other GPU types (g5/g6 etc.) also work
    instance_count=1,
    role=role,
    transformers_version="4.46",
    pytorch_version="2.3",
    py_version="py311",
    hyperparameters=hyperparameters,
)
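
For the distributed training case mentioned above, one option on a multi-GPU instance is the estimator's `distribution` argument. The sketch below only builds that argument; `distributed_kwargs` is a hypothetical helper, and the `torch_distributed` setting is an assumption to check against the SageMaker SDK documentation for your version.

```python
def distributed_kwargs(enabled: bool = True) -> dict:
    """Extra estimator arguments for a torchrun-style multi-GPU launch (illustrative)."""
    return {
        # Passed to the estimator itself, not inside the hyperparameters map
        "distribution": {"torch_distributed": {"enabled": enabled}},
    }


extra = distributed_kwargs()
print(extra)
```

The returned dictionary could be unpacked into the `HuggingFace(...)` call next to `hyperparameters`.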

## Training Job

In [None]:
estimator.fit({
    "train": s3_train,   # -> SM_CHANNEL_TRAIN
    "test":  s3_test,    # -> SM_CHANNEL_TEST
})
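
The `scripts/train.py` entry point is not shown in this notebook. As a minimal sketch of how such a script might pick up these inputs: SageMaker passes each hyperparameter to the script as a command-line flag and exposes each channel's local directory through an `SM_CHANNEL_<NAME>` environment variable. The `parse_args` helper, the argument names for the directories, and the defaults are assumptions, not the actual script.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and SageMaker-provided paths (illustrative sketch)."""
    parser = argparse.ArgumentParser()
    # Hyperparameters arrive as CLI flags, e.g. --model_name bert-base-cased
    parser.add_argument("--model_name", type=str, default="bert-base-cased")
    parser.add_argument("--num_train_epochs", type=int, default=1)
    # Channel directories come from environment variables set by SageMaker;
    # the fallbacks are the usual in-container locations
    parser.add_argument("--train_dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument("--test_dir", default=os.environ.get("SM_CHANNEL_TEST", "/opt/ml/input/data/test"))
    # Files written to SM_MODEL_DIR are packaged into model.tar.gz
    parser.add_argument("--model_dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    # Tolerate any extra flags the script does not know about
    args, _ = parser.parse_known_args(argv)
    return args


args = parse_args(["--model_name", "bert-base-cased", "--num_train_epochs", "1"])
print(args.model_name, args.num_train_epochs, args.train_dir)
```

Inside the script, `train.csv` and `test.csv` would then be read from `args.train_dir` and `args.test_dir`.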

## Extract Model Data

In [None]:
estimator.model_data

In [None]:
from urllib.parse import urlparse

import boto3

model_s3_uri = estimator.model_data   # e.g., s3://my-bucket/path/to/model.tar.gz

# split the S3 URI into bucket and key
parsed = urlparse(model_s3_uri)
model_bucket = parsed.netloc          # separate name so the earlier `bucket` is not overwritten
model_key = parsed.path.lstrip("/")   # remove leading "/"

s3 = boto3.client("s3")
s3.download_file(model_bucket, model_key, "model.tar.gz")

In [None]:
import tarfile
with tarfile.open("model.tar.gz") as tar:
    tar.extractall("model_dir")
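
To see how much of the tarball is taken up by checkpoints versus the final model files, the archive can be summarized by top-level entry. This is a small sketch using only the standard library; `tarball_sizes` is a hypothetical helper, and the exact directory names inside the archive depend on what the training script saved.

```python
import os
import posixpath
import tarfile
from collections import defaultdict


def tarball_sizes(path: str) -> dict:
    """Return total bytes per top-level entry in a tar archive."""
    sizes = defaultdict(int)
    with tarfile.open(path) as tar:
        for member in tar.getmembers():
            if member.isfile():
                # Normalize names such as "./config.json" and group by first path component
                name = posixpath.normpath(member.name).lstrip("/")
                sizes[name.split("/")[0]] += member.size
    return dict(sizes)


# Only runs when the archive downloaded above is present
if os.path.exists("model.tar.gz"):
    for name, size in sorted(tarball_sizes("model.tar.gz").items(), key=lambda kv: -kv[1]):
        print(f"{size / 1e6:10.1f} MB  {name}")
```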