Copyright 2024 Google LLC.

SPDX-License-Identifier: Apache-2.0

This Colab example initiates a tuning job on base model `translation-llm-002` with provided tsv, csv or tmx datasets.

## Prerequisites

*   Enable Vertex AI API in APIs & services page.
*   It's recommended to use Colab for running the script, as it best supports the authentication process and Cloud CLI.

## Authentication

In [None]:
# Colab helper used to sign the notebook user in to Google Cloud.
from google.colab import auth

# ID of the Google Cloud project; later cells use it for the Cloud CLI
# configuration, for vertexai.init, and in the returned model name.
PROJECT_ID = "my-project"  # @param {type:"string"}
auth.authenticate_user(project_id=PROJECT_ID)

In [None]:
# Make PROJECT_ID the active project for Cloud CLI commands run from this notebook.
!gcloud config set project {PROJECT_ID}

## Input parameters

To get started quickly, upload your input datasets to Colab and fill in the required parameters.

By default, the model name to use in translate text requests is returned after tuning finishes. For reference, tuning a dataset with 10k training examples takes less than 20 minutes.

In [None]:
# Cloud Storage directory the converted .jsonl datasets are uploaded to.
GCS_UPLOAD_PATH = 'gs://my_bucket/dir' # @param {type:"string"}

# Model display name on Vertex AI Online Prediction page.
TUNED_MODEL_DISPLAY_NAME = 'translation-llm-test' # @param {type:"string"}

# Language codes of the source and target text. Each must be a key of
# `language_map`, otherwise the conversion step raises ValueError.
SOURCE_LANGUAGE_CODE = 'en' # @param {type:"string"}
TARGET_LANGUAGE_CODE = 'es' # @param {type:"string"}

# Colab path for train/validation data.
# tsv, csv and tmx are supported. For tsv/csv the first column is read as the
# source sentence and the second column as the target sentence.
TRAIN_FILE_PATH = '/content/train.tsv' # @param {type:"string"}
VALIDATION_FILE_PATH = '/content/validation.tsv' # @param {type:"string"}

# Set sample size: at most this many examples are taken from the start of the
# file. Set to "-1" to use all examples.
TRAIN_DATASET_SAMPLE_SIZE = -1 # @param {type:"integer"}

# Validation size limit is 1000.
VALIDATION_DATASET_SAMPLE_SIZE = 250 # @param {type:"integer"}


## Helper functions

In [None]:
# translate-toolkit provides the TMX reader (translate.storage.tmx) imported by the helper cell.
%pip install --upgrade translate-toolkit

In [None]:
# Region used when initialising Vertex AI and when building the tuned model
# name; only us-central1 is supported for now.
LOCATION = 'us-central1'

# Language code -> English language name. The name (not the code) is what gets
# written into each tuning prompt, and only codes listed here are accepted by
# the conversion helper.
language_map = {
    'en': 'English',
    'es': 'Spanish',
    'fr': 'French',
    'de': 'German',
    'it': 'Italian',
    'pt': 'Portuguese',
    'zh': 'Chinese',
    'ja': 'Japanese',
    'ko': 'Korean',
    'ar': 'Arabic',
    'hi': 'Hindi',
    'ru': 'Russian',
}

In [None]:
import csv
import json
import glob
import os
import time
from translate.storage.tmx import tmxfile

from google.cloud import translate_v3
from google.cloud import storage

import vertexai
from vertexai.tuning import sft


# Builds one tuning example as a single-line JSON string.
#
# The user turn holds the prompt "<source_language>: <source_sentence> <target_language>: "
# and the model turn holds the expected translation. Non-ASCII characters are
# written as-is rather than as \u escapes.
def convert_line_to_jsonl(source_language, target_language, source_sentence, target_sentence):
  prompt = "".join((source_language, ": ", source_sentence, " ", target_language, ": "))
  user_turn = {"role": "user", "parts": [{"text": prompt}]}
  model_turn = {"role": "model", "parts": [{"text": target_sentence}]}
  example = {"contents": [user_turn, model_turn]}
  return json.dumps(example, ensure_ascii=False)


# Yields (source, target) sentence pairs from an open tsv/csv text file.
def _iter_delimited_pairs(infile, delimiter):
  reader = csv.reader(infile, delimiter=delimiter)
  for row in reader:
    # csv.reader yields an empty list for a blank line (e.g. a trailing
    # newline at the end of the file); such lines carry no example.
    if not row:
      continue
    if len(row) < 2:
      raise ValueError(
          f"Line {reader.line_num}: expected a source and a target column, found {len(row)}")
    yield row[0], row[1]


# Yields (source, target) sentence pairs from an open binary tmx file.
def _iter_tmx_pairs(infile, source_language_code, target_language_code):
  # Pass the actual language codes, not the literal parameter names.
  # NOTE(review): confirm how tmxfile uses these codes when reading units.
  tmx_file = tmxfile(infile, source_language_code, target_language_code)
  for node in tmx_file.unit_iter():
    # A unit without both sides cannot form a tuning example.
    if node.source is None or node.target is None:
      continue
    yield node.source, node.target


# Format conversion function for single file input. Output file will have the same name but in jsonl format.
#
# Args:
#   input_file: path of a .tsv, .csv or .tmx file (extension matched
#     case-insensitively). For tsv/csv the first two columns are the source and
#     target sentences.
#   sample_size: maximum number of examples to write, taken from the start of
#     the file; a negative value (e.g. -1) writes all of them.
#   source_language_code / target_language_code: keys of `language_map`.
#
# Returns:
#   Path of the written .jsonl file (input path with its extension replaced).
#
# Raises:
#   ValueError: unknown language code, unsupported file extension, or a
#     tsv/csv row with fewer than two columns.
def convert_file_format(input_file, sample_size, source_language_code, target_language_code):
  if source_language_code not in language_map or target_language_code not in language_map:
    raise ValueError("Invalid language code")

  name, ext = os.path.splitext(input_file)
  ext = ext.lower()
  # Reject unsupported types before any file is created.
  if ext not in ('.tsv', '.csv', '.tmx'):
    raise ValueError("Invalid file type")
  output_file = name + '.jsonl'

  source_language = language_map[source_language_code]
  target_language = language_map[target_language_code]

  if ext == '.tmx':
    infile = open(input_file, 'rb')
  else:
    # newline='' lets the csv module handle line endings inside quoted fields;
    # utf-8-sig drops a leading BOM so it does not end up in the first sentence.
    infile = open(input_file, 'r', encoding='utf-8-sig', newline='')

  with infile, open(output_file, 'w', encoding='utf-8') as outfile:
    if ext == '.tmx':
      pairs = _iter_tmx_pairs(infile, source_language_code, target_language_code)
    else:
      pairs = _iter_delimited_pairs(infile, '\t' if ext == '.tsv' else ',')

    written = 0
    for source_sentence, target_sentence in pairs:
      # A negative sample_size never equals the count, so everything is written.
      if written == sample_size:
        break
      outfile.write(convert_line_to_jsonl(source_language, target_language, source_sentence, target_sentence))
      outfile.write('\n')
      written += 1

  return output_file


# Initiates model training.
#
# Starts a supervised tuning job on base model `translation-llm-002` in
# PROJECT_ID / LOCATION, waits (polling once a minute) until the job has ended,
# then prints and returns the custom translation model name built from the
# tuned model's endpoint.
#
# Args:
#   train_dataset_path: training dataset location passed to sft.train.
#   validation_dataset_path: validation dataset location passed to sft.train.
#   tuned: display name given to the tuned model.
#
# Returns:
#   "projects/<project>/locations/<location>/models/translation-llm-custom/<endpoint id>".
#
# Raises:
#   RuntimeError: the job ended without reporting a tuned model endpoint.
def train_model(train_dataset_path, validation_dataset_path, tuned):
  vertexai.init(project=PROJECT_ID, location=LOCATION)

  sft_tuning_job = sft.train(
    source_model="translation-llm-002",
    train_dataset=train_dataset_path,
    validation_dataset=validation_dataset_path,
    # Use the caller-supplied display name instead of silently ignoring it.
    tuned_model_display_name=tuned,
  )

  # Polling for job completion
  while not sft_tuning_job.has_ended:
    time.sleep(60)
    sft_tuning_job.refresh()

  # "Ended" is not the same as "succeeded": without an endpoint there is no
  # model name to build, so fail with a clear message instead of on None.rsplit.
  endpoint_name = sft_tuning_job.tuned_model_endpoint_name
  if not endpoint_name:
    raise RuntimeError("Tuning job ended without a tuned model endpoint; check the job status in Vertex AI.")

  endpoint_short_name = endpoint_name.rsplit('/', 1)[-1]
  custom_model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/translation-llm-custom/{endpoint_short_name}"

  print("Model: ", custom_model_name)
  return custom_model_name


## Dataset Format Conversion

This step converts data to `.jsonl` format for tuning.

In [None]:
# Convert both datasets to .jsonl next to the original files; each call
# returns the path of the file it wrote.
train_jsonl = convert_file_format(
    input_file=TRAIN_FILE_PATH,
    sample_size=TRAIN_DATASET_SAMPLE_SIZE,
    source_language_code=SOURCE_LANGUAGE_CODE,
    target_language_code=TARGET_LANGUAGE_CODE,
)
validation_jsonl = convert_file_format(
    input_file=VALIDATION_FILE_PATH,
    sample_size=VALIDATION_DATASET_SAMPLE_SIZE,
    source_language_code=SOURCE_LANGUAGE_CODE,
    target_language_code=TARGET_LANGUAGE_CODE,
)
# Drop any trailing "/" so later cells can append "/<file name>" safely.
GCS_UPLOAD_PATH = GCS_UPLOAD_PATH.rstrip('/')

In [None]:
!gsutil cp {train_jsonl} {GCS_UPLOAD_PATH}
!gsutil cp {validation_jsonl} {GCS_UPLOAD_PATH}

## Initiate Vertex Tuning Request

After tuning is done, the translation model name to use for translation requests is returned.

In [None]:
# Cloud Storage locations of the uploaded datasets: the upload directory plus
# the file name of each converted .jsonl file.
train_file_name = os.path.basename(train_jsonl)
validation_file_name = os.path.basename(validation_jsonl)
train_dataset_path = f"{GCS_UPLOAD_PATH}/{train_file_name}"
validation_dataset_path = f"{GCS_UPLOAD_PATH}/{validation_file_name}"

# Start tuning and keep the resulting model name for translation requests.
custom_model_name = train_model(
    train_dataset_path,
    validation_dataset_path,
    TUNED_MODEL_DISPLAY_NAME,
)