In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Synthetic Data Generation using Llama 3.1

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fsynthetic_data_generation_using_llama3_1.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/synthetic_data_generation_using_llama3_1.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates generating synthetic data using the [Llama 3.1 405B service API](https://console.cloud.google.com/vertex-ai/publishers/meta/model-garden/llama3-405b-instruct-maas).


### Objective

Leverage the Llama 3.1 405B service API to generate synthetic data. The framework is based on [Snowfakery](https://snowfakery.readthedocs.io/en/latest/), which is itself based on [Faker](https://faker.readthedocs.io/en/master/). It requires the expected output to be described in a YAML recipe that follows the Snowfakery specification, detailing all the required fields and their respective data-generation strategies.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Steps

In [None]:
# @title Setup Google Cloud project

!pip install --upgrade --user -q openai snowfakery==3.6.2 wikipedia-api==0.6.0

import os
from datetime import datetime

from google.cloud import aiplatform

# Define project information
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]
REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
now = datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_URI = "gs://"  # @param {type:"string"}

if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}"
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )
print(f"Using this GCS Bucket: {BUCKET_URI}")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
MODEL_BUCKET = os.path.join(BUCKET_URI, "llama_3_1")


# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

from google.colab import auth

auth.authenticate_user(project_id=PROJECT_ID)


import google.auth

# Programmatically get an access token
creds, _ = google.auth.default(
    scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)
# Note: the credential lives for 1 hour by default (https://cloud.google.com/docs/authentication/token-types#at-lifetime); after expiration, it must be refreshed.

In [None]:
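# @title Creating Plugins

# @markdown This cell creates the two custom Snowfakery plugins needed for this use case: one that calls the Llama 3.1 405B service API and one that fetches Wikipedia pages. The prompt templates they use are created in the cells that follow.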

import logging
import sys
import types
from io import StringIO

import jinja2
import openai
import wikipediaapi
from snowfakery import generate_data
from snowfakery.plugins import SnowfakeryPlugin

MODEL_ID = "meta/llama3-405b-instruct-maas"
ENDPOINT = "aiplatform.googleapis.com"

# Pass the Vertex endpoint and authentication to the OpenAI SDK.
# The regional host must agree with the `locations/{REGION}` segment of the
# path, so both are derived from REGION.
# NOTE(review): the Llama 3.1 405B MaaS endpoint may only be offered in some
# regions (e.g. us-central1) — set GOOGLE_CLOUD_REGION accordingly.
client = openai.OpenAI(
    base_url=f"https://{REGION}-{ENDPOINT}/v1beta1/projects/{PROJECT_ID}/locations/{REGION}/endpoints/openapi",
    # Short-lived OAuth access token used as the bearer "API key".
    api_key=creds.token,
)


class SyntheticDataGeneration:
    """
    Implements all the extra functionality needed for this use case.

    ``Plugins`` subclasses ``types.ModuleType`` so that an instance of it can be
    registered in ``sys.modules`` and referenced from a recipe as
    ``SyntheticDataGeneration.Plugins.<PluginName>``.
    """

    class Plugins(types.ModuleType):
        """
        Provides the plugins needed to extend Snowfakery.
        """

        # The first plugin allows us to interact with the Llama 3.1 405B service API.
        class Llama3(SnowfakeryPlugin):
            """
            Plugin for interacting with the Llama 3.1 service API.
            """

            class Functions:
                """
                Functions available to recipes as ``Llama3.<function_name>``.
                """

                def fill_prompt(self, prompt_name: str, **kwargs) -> str:
                    """
                    Renders the Jinja2 template file ``prompt_name`` with ``kwargs``.

                    Templates are looked up relative to the current working
                    directory.
                    """
                    return (
                        jinja2.Environment(
                            loader=jinja2.FileSystemLoader(searchpath="./")
                        )
                        .get_template(prompt_name)
                        .render(**kwargs)
                    )

                def generate(
                    self,
                    prompt_name: str,
                    model="Llama3",
                    temperature=0.9,
                    top_p=1,
                    **kwargs,
                ) -> str | None:
                    """
                    Renders a prompt template and sends it to the Llama 3.1 model.

                    Args:
                        prompt_name: File name of the Jinja2 prompt template.
                        model: Unused; requests always target ``MODEL_ID``. Kept
                            so recipes that pass it keep working.
                        temperature: Sampling temperature forwarded to the API.
                        top_p: Nucleus-sampling value forwarded to the API.
                        **kwargs: Variables used to render the prompt template.

                    Returns:
                        The generated text, or ``None`` if the API call failed.
                        Failures are logged instead of raised so that a single
                        failed request does not abort the whole generation run.
                    """
                    prompt = self.fill_prompt(prompt_name, **kwargs)
                    try:
                        response = client.chat.completions.create(
                            model=MODEL_ID,
                            messages=[{"role": "user", "content": prompt}],
                            temperature=temperature,
                            top_p=top_p,
                        )
                        return response.choices[0].message.content
                    except Exception as e:  # best effort: log and continue
                        # The stdlib `logging` module has no `trace` level
                        # function; use `error` so failures are visible.
                        logging.error(
                            (
                                "Unable to generate text using %s.\n"
                                "Prepared Prompt: \n%s\n\nError: %s"
                            ),
                            prompt_name,
                            prompt,
                            e,
                        )
                        return None

        # The second plugin gives us the ability to interact with Wikipedia and fetch the contents for a given page.
        class Wikipedia(SnowfakeryPlugin):
            """
            Plugin for interacting with Wikipedia.
            """

            class Functions:
                """
                Implements a single function to fetch a Wikipedia page.
                """

                def get_page(self, title: str):
                    """
                    Returns the title, URL and sections of the given Wikipedia page.

                    Args:
                        title: Title of an English Wikipedia page.

                    Returns:
                        A dict with keys ``title``, ``url`` and ``sections``.
                        ``sections`` maps each section title to that section's
                        text, in page order. Nested section titles are prefixed
                        with their parent's title (``"Parent - Child"``).
                        Sections without text are omitted. Sections titled
                        exactly "External links", "References", "See also" or
                        "Further reading" are skipped together with their
                        subsections.

                    Raises:
                        ValueError: If the page does not exist.
                    """
                    logging.info("Parsing Wikipedia Page %s", title)
                    page = wikipediaapi.Wikipedia(
                        "Snowfakery (example@google.com)", "en"
                    ).page(title)
                    if not page.exists():
                        raise ValueError(f"Wikipedia page {title!r} was not found.")
                    skipped_sections = {
                        "External links",
                        "References",
                        "See also",
                        "Further reading",
                    }
                    results = {"sections": {}, "title": page.title, "url": page.fullurl}
                    # Depth-first walk with an explicit stack. Children are pushed
                    # in reverse so they are popped, and stored, in page order.
                    stack = [(s.title, s) for s in reversed(page.sections)]
                    while stack:
                        sec_title, sec_obj = stack.pop()
                        if sec_title in skipped_sections:
                            continue
                        if sec_obj.text:
                            results["sections"][sec_title] = sec_obj.text
                        stack.extend(
                            (f"{sec_title} - {sub_sec.title}", sub_sec)
                            for sub_sec in reversed(sec_obj.sections)
                        )
                    logging.info("Parsing Wikipedia Page %s Complete", title)
                    return results

In [None]:
# @title Making plugins discoverable

# @markdown The recipe refers to the plugins by dotted path (for example `SyntheticDataGeneration.Plugins.Llama3`), so an instance of the `Plugins` class is registered in `sys.modules` under that module name, where Snowfakery can import it.

sys.modules.update(
    {
        "SyntheticDataGeneration.Plugins": SyntheticDataGeneration.Plugins(
            "SyntheticDataGeneration.Plugins"
        )
    }
)

In [None]:
# @title Creating Prompt Templates

# A `%%writefile` cell magic is only recognised on the first line of a cell, so
# it cannot follow the `# @title` form line above. The template is therefore
# written from Python instead; it is rendered later with Jinja2.
from pathlib import Path

BLOG_PROMPT_TEMPLATE = """\
You are an expert content creator who writes detailed, factual blogs.
You have been asked to write a blog about {{idea_title}}.
To get you started, you have also been given the following context about the topic:

{{idea_body}}

Ensure the blog that you write is interesting,detailed and factual.
Take a deep breath and start writing:
"""

Path("blog_generator.jinja").write_text(BLOG_PROMPT_TEMPLATE, encoding="utf-8")
print("Wrote blog_generator.jinja")

In [None]:
# Write the comment-generation prompt template into the current working
# directory, reporting "Writing"/"Overwriting" just like `%%writefile` does.
import os

comment_template_file = "comment_generator.jinja"
print(
    ("Overwriting %s" if os.path.exists(comment_template_file) else "Writing %s")
    % comment_template_file
)
with open(comment_template_file, "w", encoding="utf-8") as handle:
    handle.write(
        "You are {{first_name}} {{last_name}}. You are {{age}} years old. You are interested in {{interests}}. You work at {{organization}} as a {{profession}}.\n"
        "You came across the following article:\n"
        "\n"
        "{{blog_title}}\n"
        "\n"
        "{{blog_body}}\n"
        "\n"
        "Present your thoughts and feelings about the article in a short comment.\n"
        "\n"
        "Comment:\n"
    )

In [None]:
# @title Creating the Recipe

# @markdown In order to generate synthetic data, the schema of the synthetic data must be defined first. This is done by creating a recipe in YAML format, as demonstrated below. See the [Snowfakery documentation](https://snowfakery.readthedocs.io/en/latest/#central-concepts) for more details on writing recipes.

# How the recipe fits together:
# - `users` are fake readers; each blog comment references one at random.
# - `seeds` holds the Wikipedia page chosen via the `wiki_title` option.
# - One `blog_ideas` row is created per section returned by `Wikipedia.get_page`,
#   and each idea gets one Llama-generated `blog_posts` row with one comment.
recipe = """
- plugin: SyntheticDataGeneration.Plugins.Wikipedia
- plugin: SyntheticDataGeneration.Plugins.Llama3
- option: wiki_title
- var: __seed
  value:
    - Wikipedia.get_page :
      title : ${{wiki_title}}

- object : users
  count : ${{random_number(min=100, max=500)}}
  fields :
    first_name : ${{fake.FirstName}}
    last_name : ${{fake.LastName}}
    age:
      random_number:
        min: 18
        max: 95
    email : ${{fake.Email}}
    phone : ${{fake.PhoneNumber}}
    interests : ${{fake.Bs}}
    postal_code : ${{fake.Postalcode}}
    organization : ${{fake.Company}}
    profession : ${{fake.Job}}

- object : seeds
  fields :
    title : ${{__seed['title']}}
    url : ${{__seed['url']}}
    section_count : ${{__seed['sections'] | length}}

  friends:
    - object : blog_ideas
      count : ${{seeds.section_count}}
      fields :
        seed_id : ${{seeds.id}}
        section : ${{(__seed.sections.keys() | list)[child_index]}}
        body : ${{__seed.sections[section]}}

      friends:
        - object : blog_posts
          fields :
            blog_idea_id : ${{blog_ideas.id}}
            title : ${{seeds.title}} - ${{blog_ideas.section}}
            body :
              - Llama3.generate:
                prompt_name : blog_generator.jinja
                idea_title : ${{title}}
                idea_body : ${{blog_ideas.body}}
            author : Llama3

          friends:
            - object : blog_post_comments
              fields :
                blog_post_id : ${{blog_posts.id}}
                author_id :
                  random_reference : users
                author_email : ${{author_id.email}}
                comment :
                  - Llama3.generate:
                    prompt_name : comment_generator.jinja
                    first_name : ${{author_id.first_name}}
                    last_name : ${{author_id.last_name}}
                    age : ${{author_id.age}}
                    interests : ${{author_id.interests}}
                    organization : ${{author_id.organization}}
                    profession : ${{author_id.profession}}
                    blog_title : ${{blog_posts.title}}
                    blog_body : ${{blog_posts.body | truncate(1000)}}
"""

In [None]:
# @title Generating Data

# @markdown Title of the English Wikipedia page used to seed the generated blog posts.
WIKI_TITLE = "Python_(programming_language)"  # @param {type:"string"}
# @markdown Folder where the generated CSV files are written.
OUTPUT_FOLDER = "outputs"  # @param {type:"string"}

# `wiki_title` is the recipe option consumed by `Wikipedia.get_page`.
generate_data(
    StringIO(recipe),
    output_format="csv",
    output_folder=OUTPUT_FOLDER,
    user_options={"wiki_title": WIKI_TITLE},
)

# @markdown The synthetic data has been generated and stored as CSV files in the output folder (`outputs` by default).