### 1. Setup development environment

In [None]:
# %pip install --upgrade --quiet boto3 sagemaker transformers datasets plotly

In [None]:
import json, boto3, sagemaker

# Dataset and base-model identifiers used throughout the notebook.
dataset_id = "deepmind/code_contests"
model_id = "mistral-community/Codestral-22B-v0.1"

# Training-job naming and the S3 layout: processed datasets and run outputs
# all live under s3://<workspace_bucket_name>/<s3_prefix>/.
base_job_name = "fsdp-codestral"
workspace_bucket_name = "research-agi"
s3_prefix = "mistral-community-codestral-22b-v0x1"
s3_train_dataset_path = f"s3://{workspace_bucket_name}/{s3_prefix}/train"
s3_test_dataset_path = f"s3://{workspace_bucket_name}/{s3_prefix}/test"
s3_save_model_dir = f"s3://{workspace_bucket_name}/{s3_prefix}/runs/"

# Execution role and a SageMaker session whose default bucket is the workspace bucket.
role = sagemaker.get_execution_role()
session = sagemaker.session.Session(default_bucket=workspace_bucket_name)
# Read the region through the public property rather than the private
# `_region_name` attribute, which is an implementation detail of Session.
region = session.boto_region_name

### 2. Create and prepare dataset

In [None]:
from utils import data_utils

In [None]:
# Build the training set from the first 60% of the "train" split via
# data_utils.load_and_process, then show a summary of it.
train_dataset = data_utils.load_and_process(split="train[:60%]", dataset_id=dataset_id)
print(f"train_dataset: {train_dataset}")

# Write the processed training set to its S3 location and echo the path.
train_dataset.save_to_disk(s3_train_dataset_path)
print(f"s3_train_dataset_path: {s3_train_dataset_path}")

In [None]:
# Build the evaluation set from the full "test" split via
# data_utils.load_and_process, then show a summary of it.
test_dataset = data_utils.load_and_process(split="test", dataset_id=dataset_id)
print(f"test_dataset: {test_dataset}")

# Write the processed test set to its S3 location and echo the path.
test_dataset.save_to_disk(s3_test_dataset_path)
print(f"s3_test_dataset_path: {s3_test_dataset_path}")

In [None]:
import plotly.express as px

def plot(tokenized_train_dataset, tokenized_test_dataset):
    """Show a histogram of `input_ids` lengths over both datasets.

    Args:
        tokenized_train_dataset: Iterable of rows that each have an
            "input_ids" entry supporting `len()`.
        tokenized_test_dataset: Same structure as the training rows; its
            lengths are appended after the training ones.
    """
    sequence_lengths = []
    for split in (tokenized_train_dataset, tokenized_test_dataset):
        sequence_lengths.extend(len(row["input_ids"]) for row in split)

    histogram = px.histogram(sequence_lengths)
    histogram.show()

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS token for padding and pad on the right-hand side.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


def template_dataset(examples):
    """Render one example's chat messages with the chat template and tokenize them.

    The result is stored under "input_ids", so it has to be token ids: with
    `tokenize=False` the column held the rendered string and any length
    computed from it counted characters instead of tokens.

    Args:
        examples: A single dataset row with a "messages" entry (the map calls
            below are not batched).

    Returns:
        A dict with one key, "input_ids", holding the token ids of the
        templated conversation.
    """
    return {
        "input_ids": tokenizer.apply_chat_template(
            examples["messages"],
            tokenize=True,
            # Ask for a plain list of ids regardless of the library's default
            # return container.
            return_dict=False,
            # Deliberately no truncation/max_length here: the full, untruncated
            # lengths are what the length histogram needs to show.
        )
    }


# Replace the raw "messages" column with the tokenized "input_ids" column.
tokenized_train_dataset = train_dataset.map(template_dataset, remove_columns=["messages"])
tokenized_test_dataset = test_dataset.map(template_dataset, remove_columns=["messages"])

In [None]:
# Histogram of per-example `input_ids` lengths across the train and test splits.
plot(tokenized_train_dataset, tokenized_test_dataset)

### 3. Set arguments

In [None]:
# Hyperparameters handed to the training entry point through the estimator.
hyperparameters = {
    ### training related
    "dataset_path": "/opt/ml/input/data",
    "sm_save_model_dir": "/opt/ml/model",
    "output_dir": "/tmp",
    "logging_dir": "/tmp/logs",

    # Reuse the notebook-level model_id so the trained model always matches
    # the one the tokenizer above was loaded from.
    "model_id": model_id,
    "num_train_epochs": 1,
    "max_steps": -1,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {
        "use_reentrant": False,
    },
    "bf16": True,
    "tf32": True,
    "max_grad_norm": 0.3,
    "weight_decay": 0.001,
    "optim": "adamw_torch",
    "learning_rate": 0.0002,
    # NOTE(review): with lr_scheduler_type "constant" the warmup_ratio is
    # presumably ignored by the trainer -- confirm, or switch to
    # "constant_with_warmup" if warmup is actually wanted.
    "warmup_ratio": 0.03,
    "lr_scheduler_type": "constant",
    "save_strategy": "no",
    "logging_steps": 25,
    "logging_strategy": "steps",
    "group_by_length": True,
    "max_seq_length": 4096,
    "packing": False,
    "finetune_with_sm": True,
    "merge_weights_and_save": True,
    "save_tokenizer": True,
    "attn_implementation": "sdpa",

    ### qlora related
    "lora_r": 64,
    "lora_alpha": 16,
    "lora_dropout": 0.1,
    "task_type": "CAUSAL_LM",

    ### bitsandbytes related
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16",
    "bnb_4bit_quant_storage": "bfloat16",
}

# Echo the final configuration; default=str stringifies anything json cannot encode.
print('Hyperparameters: \n', json.dumps(hyperparameters, indent=2, default=str))

### 4. Begin training!

In [None]:
from sagemaker.pytorch import PyTorch

# SageMaker PyTorch estimator that runs scripts/sft_fsdp_qlora.py with the
# hyperparameters defined above.
estimator = PyTorch(
    # Training code.
    entry_point="sft_fsdp_qlora.py",
    source_dir="./scripts",
    hyperparameters=hyperparameters,
    # Job identity and permissions.
    base_job_name=base_job_name,
    role=role,
    sagemaker_session=session,
    # Container selection.
    framework_version="2.3.0",
    py_version="py311",
    # Compute and storage.
    instance_type="ml.p4d.24xlarge",  # gpus=8
    instance_count=1,
    volume_size=300,
    max_run=1 * 24 * 60 * 60,  # days * hours * minutes * seconds
    keep_alive_period_in_seconds=1800,
    # Runtime behaviour.
    distribution={"torch_distributed": {"enabled": True}},  # enable torchrun
    environment={"HUGGINGFACE_HUB_CACHE": "/tmp/.cache"},
    disable_profiler=True,
    debugger_hook_config=False,
    # Output location.
    output_path=s3_save_model_dir,
    disable_output_compression=True,
)

# Input channels handed to estimator.fit: channel name -> S3 dataset location.
data = {"train": s3_train_dataset_path, "test": s3_test_dataset_path}

print(f"training_image_uri: {estimator.training_image_uri()}")
print(f"data: {json.dumps(data, indent=2, default=str)}")

In [None]:
%%time
# Start the training job with the "train" and "test" channels defined in `data`.
estimator.fit(data, wait=True)

In [None]:
# Show the model artifact location the estimator reports after fitting.
print(f"estimator.model_data: {estimator.model_data}")