In [1]:
from google.cloud import aiplatform
import os
from datetime import datetime

In [2]:
# Google Cloud project, region and bucket used by every Vertex AI call below.
PROJECT_ID = "mlops-explorations"
REGION = "us-central1"
BUCKET_URI = "gs://jax-fine-tuning-gemma"
# Service account the training job runs as (passed to `train_job.run`).
SERVICE_ACCOUNT = "939282436854-compute@developer.gserviceaccount.com"

# Source of the Dolly-15k instruction dataset on the Hugging Face Hub.
jsonl_dataset_file = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
# TODO(davidnet): Copy the dataset to the bucket. Nothing in this notebook
# uploads `jsonl_dataset_file` to `jsonl_dataset_uri`, so that object must
# already exist in the bucket before the training job is launched.
jsonl_dataset_uri = f"{BUCKET_URI}/dataset.jsonl"
# Cloud Storage FUSE path for the same object (gs://bucket/... -> /gcs/bucket/...),
# which is how the training container is told where to read it. Only the
# leading scheme is rewritten.
jsonl_dataset_uri_gcsfuse = jsonl_dataset_uri.replace("gs://", "/gcs/", 1)
# Pre-built Model Garden training image (JAX/Keras on TPU, per its name).
KERAS_TRAIN_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/jax-keras-train-tpu:20240220_0936_RC01"
# Staging area for Vertex AI artifacts. GCS URIs always use "/" as the
# separator; os.path.join would insert "\" when the notebook runs on Windows.
STAGING_BUCKET = f"{BUCKET_URI.rstrip('/')}/temporal"
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

In [3]:
def get_job_name_with_datetime(prefix: str) -> str:
    """Build a timestamped name for a Vertex AI training or deployment job.

    Args:
        prefix: Leading part of the job name.

    Returns:
        ``prefix`` followed by the current local time formatted as
        ``_YYYYMMDD_HHMMSS``.
    """
    started_at = datetime.now()
    suffix = started_at.strftime("_%Y%m%d_%H%M%S")
    return "".join([prefix, suffix])

In [4]:
# Prompt template forwarded to the training container via --template.
# Presumably {instruction}/{response} are filled from the JSONL record fields
# of the same name — confirm against the container.
template = "Instruction: {instruction} Response: {response}"
# TFDS dataset name and split forwarded via --tfds_dataset_name/--tfds_dataset_split.
# The name is left empty here; the job is also given --jsonl_dataset_file.
tfds_dataset_name, tfds_dataset_split = "", "train"



In [5]:
# The Gemma base model.
KAGGLE_MODEL_ID = "gemma_2b_en"
# Number of training examples, passed as --num_train_subset_samples.
num_train_subset_samples = 2000

# Number of training epochs.
num_epochs = 1
# Learning rate.
learning_rate = 5e-5
# Weight decay.
weight_decay = 0.01
# Input sequence length. It determines the memory required by the model.
input_sequence_length = 512
# LoRA rank.
lora_rank = 4
# Batch size for training.
train_batch_size = 2
# The KerasNLP checkpoint filename.
# Note: Do not add folder name here.
checkpoint_filename = "fine_tuned.weights.h5"

# Worker pool spec.
machine_type = "cloud-tpu"
# NOTE: The models have been tested only with 8 cores.
accelerator_type = "TPU_V3"
# Number of TPU cores.
accelerator_count = 8
# Set model parallelism related parameters for 8 cores.
model_parallel_batch_dim = 1
model_parallel_model_dim = 8
# NOTE(review): assumes the (batch, model) parallelism mesh must cover every
# TPU core exactly — confirm against the training container. Failing here is
# cheaper than failing after the TPU job has been provisioned.
if model_parallel_batch_dim * model_parallel_model_dim != accelerator_count:
    raise ValueError(
        "model_parallel_batch_dim * model_parallel_model_dim "
        f"({model_parallel_batch_dim} * {model_parallel_model_dim}) must equal "
        f"accelerator_count ({accelerator_count})."
    )

replica_count = 1

# Kaggle credentials are forwarded to the training container as environment
# variables. Read them from this notebook's environment instead of hardcoding
# them, so secrets are never saved or committed with the notebook.
kaggle_env_vars = {}
for env_var in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    env_value = os.environ.get(env_var)
    if not env_value:
        raise EnvironmentError(
            f"Environment variable {env_var} must be set before launching the training job."
        )
    kaggle_env_vars[env_var] = env_value

# Setup training job.
job_name = get_job_name_with_datetime("gemma-keras-lora-train")

# Pass training arguments and launch job.
train_job = aiplatform.CustomContainerTrainingJob(
    display_name=job_name,
    container_uri=KERAS_TRAIN_DOCKER_URI,
)

# GCS folder for the fine-tuned model. GCS URIs always use "/", so build the
# path explicitly instead of with the OS-dependent os.path.join.
output_folder = f"{BUCKET_URI.rstrip('/')}/{job_name}"
# Same folder as a Cloud Storage FUSE path for the container; only the leading
# scheme is rewritten.
output_folder_gcsfuse = output_folder.replace("gs://", "/gcs/", 1)

train_job.run(
    args=[
        f"--model_type={KAGGLE_MODEL_ID}",
        f"--num_epochs={num_epochs}",
        f"--learning_rate={learning_rate}",
        f"--weight_decay={weight_decay}",
        f"--input_sequence_length={input_sequence_length}",
        f"--lora_rank={lora_rank}",
        f"--model_parallel_batch_dim={model_parallel_batch_dim}",
        f"--model_parallel_model_dim={model_parallel_model_dim}",
        f"--tfds_dataset_name={tfds_dataset_name}",
        f"--tfds_dataset_split={tfds_dataset_split}",
        f"--jsonl_dataset_file={jsonl_dataset_uri_gcsfuse}",
        f"--template={template}",
        f"--train_batch_size={train_batch_size}",
        f"--num_train_subset_samples={num_train_subset_samples}",
        f"--output_folder={output_folder_gcsfuse}",
        f"--checkpoint_filename={checkpoint_filename}",
    ],
    environment_variables=kaggle_env_vars,
    replica_count=replica_count,
    machine_type=machine_type,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    service_account=SERVICE_ACCOUNT,
)

print("Trained model is saved in: ", output_folder)


Training Output directory:
gs://jax-fine-tuning-gemma/temporal/aiplatform-custom-training-2024-04-08-03:30:45.873 
View Training:
https://console.cloud.google.com/ai/platform/locations/us-central1/training/5120172797238181888?project=939282436854
View backing custom job:
https://console.cloud.google.com/ai/platform/locations/us-central1/training/6002343961551699968?project=939282436854
CustomContainerTrainingJob projects/939282436854/locations/us-central1/trainingPipelines/5120172797238181888 current state:
PipelineState.PIPELINE_STATE_RUNNING
CustomContainerTrainingJob projects/939282436854/locations/us-central1/trainingPipelines/5120172797238181888 current state:
PipelineState.PIPELINE_STATE_RUNNING
CustomContainerTrainingJob projects/939282436854/locations/us-central1/trainingPipelines/5120172797238181888 current state:
PipelineState.PIPELINE_STATE_RUNNING
CustomContainerTrainingJob projects/939282436854/locations/us-central1/trainingPipelines/5120172797238181888 current state:
Pipe