In [None]:
!pip install --upgrade --user --quiet google-cloud-aiplatform datasets backoff multiprocess gcsfs


In [None]:
!pip install datasets

In [100]:
from collections import Counter
import json
from typing import Any, Callable, Dict, List, Optional, Union
import io
# Data Handling and Processing
from datasets import load_dataset
from sklearn.metrics import (
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split
import pandas as pd
import gcsfs
from google.cloud import storage

# Google Cloud Libraries
from google.api_core.exceptions import ResourceExhausted
from google.cloud import aiplatform
import vertexai
from vertexai.generative_models import (
    GenerativeModel,
    GenerationConfig,
    HarmBlockThreshold,
    HarmCategory,
)
from vertexai.preview.tuning import sft

# Multiprocessing
import multiprocess as mp
from tqdm import tqdm
import backoff
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import traceback

In [3]:
data_file = "train_40k.csv"

In [4]:
data = pd.read_csv(data_file)

In [5]:
columns_to_drop = [
    "productId",
    "Title",
    "userId",
    "Helpfulness",
    "Score",
    "Time",
    "Cat1",
]  # List of columns to drop
data = data.drop(columns=columns_to_drop)

In [7]:

cat2_classes = data["Cat2"].unique()
print(len(cat2_classes))

cat3_classes = data["Cat3"].unique()
print(len(cat3_classes))

64
464


In [8]:
# Split the dataset into train and test sets (80% train, 20% test)
train_df, test_df = train_test_split(data, test_size=0.2, random_state=42)

# Save the test data into another CSV file
train_df.to_csv("train.csv", index=False)
test_df.to_csv("train.csv", index=False)

In [9]:
PROJECT_ID = "edl-idaas-rnd-platform-d5ae"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}

vertexai.init(project=PROJECT_ID, location=LOCATION)

In [10]:
BUCKET_NAME = "sridhanya_edl-idaas-rnd-platform"  # @param {type:"string"}
BUCKET_URI = f"gs://{BUCKET_NAME}"

In [11]:
# !gsutil mb -l $LOCATION -p $PROJECT_ID $BUCKET_URI


In [12]:
def backoff_hdlr(details) -> None:
    """
    Handles backoff events.

    Args:
        details: A dictionary containing information about the backoff event.
    """
    print(f"Backing off {details['wait']:.1f} seconds after {details['tries']} tries")


def log_error(msg: str, *args: Any) -> None:
    """
    Logs an error message and raises an exception.

    Args:
        msg: The error message.
        *args: Additional arguments to be passed to the logger.
    """
    mp.get_logger().error(msg, *args)
    raise Exception(msg)


def handle_exception_threading(f: Callable) -> Callable:
    """
    A decorator that handles exceptions in a threaded environment.

    Args:
        f: The function to decorate.

    Returns:
        The decorated function.
    """

    def applicator(*args: Any, **kwargs: Any) -> Any:
        try:
            return f(*args, **kwargs)
        except:
            log_error(traceback.format_exc())

    return applicator


@handle_exception_threading
@backoff.on_exception(
    backoff.expo, ResourceExhausted, max_tries=30, on_backoff=backoff_hdlr
)
def _predict_message(message: str, model: GenerativeModel) -> Optional[str]:
    """
    Predict messages

    Args:
        message: The message to predict.
        model: The GenerativeModel to use for prediction.

    Returns:
        The predicted message, or None if an error occurred.
    """
    response = model.generate_content([message], stream=False)
    return response.text


def batch_predict(
    messages: List[str], model: GenerativeModel, max_workers: int = 4
) -> List[Optional[str]]:
    """
    Predicts the classes for a list of messages

    Args:
        - messages: list of all messages to predict
        - model: model to use for predicting.
        - max_workers: number of workers to use for parallel predictions

    Returns:
        - list of predicted labels

    """
    predictions = list()
    with ThreadPoolExecutor(max_workers) as pool:
        partial_func = partial(_predict_message, model=model)
        for message in tqdm(pool.map(partial_func, messages), total=len(messages)):
            predictions.append(message)
            pass

    return predictions

In [13]:
class VertexAIExperimentManager:
    """
    A class for managing experiments and runs in Vertex AI.
    This class encapsulates the functionality for creating experiments, logging runs,
    and retrieving experiment data in Vertex AI.
    """

    def __init__(self, project: str, location: str):
        self.project = project
        self.location = location
        self.current_experiment = None

    def init_experiment(
        self, experiment_name: str, experiment_description: Optional[str] = None
    ):
        """Initialize or switch to a specific experiment."""
        self.current_experiment = experiment_name
        aiplatform.init(
            experiment=experiment_name,
            experiment_description=experiment_description,
            experiment_tensorboard=False,
            project=self.project,
            location=self.location,
        )

    def create_experiment(
        self, experiment_name: str, experiment_description: Optional[str] = None
    ) -> None:
        """Create an Experiment on Vertex AI Experiments"""
        self.init_experiment(experiment_name, experiment_description)

    def log_run(
        self, run_name: str, params: Dict[str, Any], metrics: Dict[str, Any]
    ) -> None:
        """Log experiment run data to Vertex AI Experiments."""
        if not self.current_experiment:
            raise ValueError("No experiment initialized. Call init_experiment first.")

        aiplatform.start_run(run=run_name)
        aiplatform.log_params(params)
        aiplatform.log_metrics(metrics)
        aiplatform.end_run()

    def get_experiments_data_frame(self) -> Optional[pd.DataFrame]:
        """Retrieve a DataFrame of experiment data from Vertex AI Experiments."""
        if not self.current_experiment:
            raise ValueError("No experiment initialized. Call init_experiment first.")

        return aiplatform.get_experiment_df()

In [120]:
def create_gemini_messages(
    text: str, label: str, system_prompt: Optional[str] = None
) -> dict:
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(
        [
            {"role": "user", "content": Text},
            {"role": "model", "content": Cat3},
        ]
    )
    return {"messages": messages}


def prepare_tuning_dataset_from_df(
    tuning_df: pd.DataFrame, system_prompt: Optional[str] = None
) -> pd.DataFrame:
    """
    Prepares a tuning dataset from a pandas DataFrame for Gemini fine-tuning.
    Args:
        tuning_df: A pandas DataFrame with columns "text" and "label_text".
        system_prompt: An optional system prompt for zero-shot learning.
    Returns:
        A pandas DataFrame containing the data in the Gemini tuning format.
    """
    tuning_dataset = [
        create_gemini_messages(row["Text"], row["Cat3"], system_prompt)
        for _, row in tuning_df.iterrows()
    ]
    return pd.DataFrame(tuning_dataset)


def convert_tuning_dataset_from_automl_csv(
    automl_gcs_csv_path: str,
    system_prompt: Optional[str] = None,
    partition: str = "training",
) -> pd.DataFrame:
    """
    Converts an AutoML CSV dataset for text classification to the Gemini tuning format.
    Args:
        automl_gcs_csv_path: The GCS path to the AutoML CSV dataset.
        system_prompt: The instructions to the model.
        partition: The partition to extract from the dataset (e.g., "training", "validation", "test"). Defaults to "training".
    Returns:
        A pandas DataFrame containing the data in the Gemini tuning format.
    """
    df = pd.read_csv(automl_gcs_csv_path, names=["partition", "Text", "Cat3"])
    df_automl = df.loc[df["partition"] == partition]
    gemini_dataset = [
        create_gemini_messages(row["Text"], row["Cat3"], system_prompt)
        for _, row in df_automl.iterrows()
    ]
    return pd.DataFrame(gemini_dataset)


def convert_tuning_dataset_from_automl_jsonl(
    project_id: str,
    automl_gcs_jsonl_path: str,
    system_prompt: Optional[str] = None,
    partition: str = "training",
) -> pd.DataFrame:
    """
    Converts an AutoML JSONL dataset for text classification to the Gemini tuning format.
    Args:
        automl_gcs_jsonl_path: The GCS path to the AutoML JSONL dataset for text classification.
        system_prompt: The instructions to the model.
        partition: The partition to extract from the dataset (e.g., "training", "validation", "test"). Defaults to "training".
    Returns:
        A pandas DataFrame containing the data in the Gemini tuning format.
    """
    processed_data = []
    gcs_file_system = gcsfs.GCSFileSystem(project=project_id)
    with gcs_file_system.open(automl_gcs_jsonl_path) as f:
        for line in f:
            data = json.loads(line)
            processed_data.append(
                {
                    "Cat3": data["classificationAnnotation"]["displayName"],
                    "Text": data["textContent"],
                    "partition": data["dataItemResourceLabels"][
                        "aiplatform.googleapis.com/ml_use"
                    ],
                }
            )

    df = pd.DataFrame(processed_data)
    df_automl = df.loc[df["partition"] == partition]
    gemini_dataset = [
        create_gemini_messages(row["Text"], row["Cat3"], system_prompt)
        for _, row in df_automl.iterrows()
    ]
    return pd.DataFrame(gemini_dataset)


def validate_gemini_tuning_jsonl(gcs_jsonl_path: str) -> List[Dict]:
    """
    Validates a JSONL file on Google Cloud Storage against the Gemini tuning format.

    Args:
        gcs_jsonl_path: The GCS path to the JSONL file.

    Returns:
        A list of dictionaries representing the errors found in the file.
        Each dictionary has the following structure:
        {
            "error_type": "Error description",
            "row_index": The index of the row where the error occurred,
            "message": The error message
        }
    """

    errors = []
    storage_client = storage.Client()
    blob = storage.Blob.from_string(uri=gcs_jsonl_path, client=storage_client)

    with blob.open("r") as f:
        for row_index, line in enumerate(f):
            try:
                data = json.loads(line)
                # Check for the presence of the "messages" key
                if "messages" not in data:
                    errors.append(
                        {
                            "error_type": "Missing 'messages' key",
                            "row_index": row_index,
                            "message": f"Row {row_index} is missing the 'messages' key.",
                        }
                    )
                    continue

                messages = data["messages"]
                # Check if "messages" is a list
                if not isinstance(messages, list):
                    errors.append(
                        {
                            "error_type": "Invalid 'messages' type",
                            "row_index": row_index,
                            "message": f"Row {row_index}: 'messages' is not a list.",
                        }
                    )
                    continue

                # Validate each message in the "messages" list
                for message_index, message in enumerate(messages):
                    if not isinstance(message, dict):
                        errors.append(
                            {
                                "error_type": "Invalid message format",
                                "row_index": row_index,
                                "message": f"""Row {row_index},
                            message {message_index}: Message is not a dictionary.""",
                            }
                        )
                        continue

                    # Check for required keys in each message dictionary
                    if "role" not in message or "content" not in message:
                        errors.append(
                            {
                                "error_type": "Missing 'role' or 'content' key",
                                "row_index": row_index,
                                "message": f"Row {row_index}, message {message_index}: "
                                "Missing 'role' or 'content' key.",
                            }
                        )
                        continue

                    # Check for valid role values
                    if message["role"] not in ["system", "user", "model"]:
                        errors.append(
                            {
                                "error_type": "Invalid 'role' value",
                                "row_index": row_index,
                                "message": f"""Row {row_index}, message {message_index}:
                            Invalid 'role' value. Expected 'system', 'user', or 'model'.""",
                            }
                        )
                        continue

            except json.JSONDecodeError as e:
                errors.append(
                    {
                        "error_type": "JSON Decode Error",
                        "row_index": row_index,
                        "message": f"Row {row_index}: JSON decoding error: {e}",
                    }
                )

    return errors

In [16]:
len(train_df)

32000

In [17]:
len(test_df)

8000

In [19]:
train_df.head()

Unnamed: 0,Text,Cat2,Cat3
14307,"The concept of this toy is good. However, if y...",dogs,toys
17812,"This dryer ruined my hair!!! At first, after I...",hair care,styling tools
11020,Much to my surprise after a year of waiting th...,novelty gag toys,miniatures
15158,The tree is beautiful but upon arrival when I ...,fresh flowers live indoor plants,live indoor plants
24990,Watchmaker offered to install a new battery in...,household supplies,unknown


In [21]:
train_df.Cat2.value_counts()


Cat2
personal care         2294
dogs                  2092
nutrition wellness    1780
health care           1614
cats                  1428
                      ... 
produce                 33
baby food               32
sauces dips             32
meat seafood            25
small animals           22
Name: count, Length: 64, dtype: int64

In [22]:
train_df.Cat3.value_counts()


Cat3
unknown                 1832
shaving hair removal    1238
vitamins supplements    1071
board games              738
styling tools            670
                        ... 
fruit gifts                1
foie gras p t s            1
children s                 1
aprons smocks              1
pork                       1
Name: count, Length: 451, dtype: int64

In [29]:
filtered_df = test_df.groupby("Cat3").filter(lambda x: len(x) > 1)
val, test = train_test_split(
    filtered_df, test_size=0.75, shuffle=True, stratify=filtered_df["Cat3"], random_state=2
)

In [30]:
print(val.shape)
print(test.shape)

(1982, 3)
(5946, 3)


In [32]:
val.Cat3.value_counts()

Cat3
unknown                   108
shaving hair removal       82
vitamins supplements       61
board games                47
styling tools              45
                         ... 
cakes                       1
basic life skills toys      1
fruit leather               1
joggers                     1
aquarium heaters            1
Name: count, Length: 304, dtype: int64

In [33]:
test.Cat3.value_counts()

Cat3
unknown                 322
shaving hair removal    245
vitamins supplements    183
board games             139
styling tools           135
                       ... 
washcloths towels         1
snack gifts               1
money banks               1
crackers biscuits         1
stimulants                1
Name: count, Length: 320, dtype: int64

In [34]:
def predictions_postprocessing(text: str) -> str:
    """
    Cleans the predicted class label string.

    Args:
        text (str): The predicted class label string.

    Returns:
        str: The cleaned class label string.
    """
    return text.strip().lower()


def evaluate_predictions(
    df: pd.DataFrame,
    target_column: str = "label_text",
    predictions_column: str = "predicted_labels",
    postprocessing: bool = True,
) -> Dict[str, float]:
    """
    Batch evaluation of predictions, returns a dictionary with the metric.

    Args:
       - df (pandas.DataFrame):  a pandas dataframe with two mandatory columns, a target column with
       the actual true values, and a predictions column with the predicted values.
       - target_column (str): column name with the actual ground truth values
       - predictions_column (str): column name with the model predictions
       - postprocessing (bool): whether to apply postprocessing to predictions.

    Returns:
        Dict[str, float]: Dictionary of evaluation metrics.
    """
    if postprocessing:
        df[predictions_column] = df[predictions_column].apply(
            predictions_postprocessing
        )

    y_true = df[target_column]
    y_pred = df[predictions_column]

    metrics_report = classification_report(y_true, y_pred, output_dict=True)
    overall_macro_f1_score = f1_score(y_true, y_pred, average="macro")
    overall_micro_f1_score = f1_score(y_true, y_pred, average="micro")
    weighted_precision = precision_score(y_true, y_pred, average="weighted")
    weighted_recall = recall_score(y_true, y_pred, average="weighted")

    metrics = {
        "accuracy": metrics_report["accuracy"],
        "weighted precision": weighted_precision,
        "weighted recall": weighted_recall,
        "macro f1": overall_macro_f1_score,
        "micro f1": overall_micro_f1_score,
    }

    categories = ["business", "sport", "politics", "tech", "entertainment"]
    for category in categories:
        if category in metrics_report:
            metrics[f"{category}_f1_score"] = metrics_report[category]["f1-score"]

    return metrics

In [35]:
EXPERIMENT_NAME = "sridhanya-classification"  # @param {type:"string"}


In [36]:
experiment_manager = VertexAIExperimentManager(project=PROJECT_ID, location=LOCATION)
experiment_manager.create_experiment(
    experiment_name=EXPERIMENT_NAME,
    experiment_description="Fine-tuning Gemini 1.0 Pro for text classification",
)

In [78]:
# Create an Evaluation dataframe to store the predictions from all the experiments.
df_evals = val[0:50].copy()

In [38]:
# Join the classes into a string
classes_list_str = "\n- ".join(cat3_classes)

In [64]:
# Find the first index for each unique category in 'Cat3'
first_indices = train_df.drop_duplicates(subset='Cat3').index.tolist()


In [65]:
system_prompt_zero_shot = """TASK:
Classify the text into ONLY one of the following classes .

CLASSES:
- {classes_list_str}


INSTRUCTIONS
- Respond with ONLY one class.
- You MUST use the exact word from the list above.
- DO NOT create or use any other classes.
- CAREFULLY analyze the text before choosing the best-fitting category.

"""

In [67]:
system_prompt_few_shot = f"""TASK:
Classify the text into ONLY one of the following classes [business, entertainment, politics, sport, tech].

CLASSES:
- {classes_list_str}

INSTRUCTIONS:
- Respond with ONLY one class.
- You MUST use the exact word from the list above.
- DO NOT create or use any other classes.
- CAREFULLY analyze the text before choosing the best-fitting category.

EXAMPLES:
- EXAMPLE 1:
    <user>
    {train_df.loc[first_indices[0]].Text}
    <model>
    {train_df.loc[first_indices[0]].Cat3}

- EXAMPLE 2:
    <user>
    {train_df.loc[first_indices[1]].Text}
    <model>
    {train_df.loc[first_indices[1]].Cat3}

- EXAMPLE 3:
    <user>
    {train_df.loc[first_indices[2]].Text}
    <model>
    {train_df.loc[first_indices[21]].Cat3}

- EXAMPLE 4:
    <user>
    {train_df.loc[first_indices[3]].Text}
    <model>
    {train_df.loc[first_indices[31]].Cat3}

- EXAMPLE 4:
    <user>
    {train_df.loc[first_indices[4]].Text}
    <model>
    {train_df.loc[first_indices[4]].Cat3}

"""

In [69]:
generation_config = GenerationConfig(max_output_tokens=10, temperature=0)

safety_settings = {
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}

In [70]:
gem_pro_1_model_zero = GenerativeModel(
    "gemini-1.0-pro-002",  # e.g. gemini-1.5-pro-001, gemini-1.5-flash-001
    system_instruction=[system_prompt_zero_shot],
    generation_config=generation_config,
    safety_settings=safety_settings,
)

In [77]:
# Get the list of messages to predict
messages_to_predict = val["Text"][0:50].to_list()
# Compute the preictions
predictions_zero_shot = batch_predict(
    messages=messages_to_predict, model=gem_pro_1_model_zero, max_workers=4
)

100%|██████████| 50/50 [00:05<00:00,  9.48it/s]


In [79]:
df_evals["gem1.0-zero-shot_predictions"] = predictions_zero_shot
len(predictions_zero_shot)

50

In [81]:
# Compute Evaluation Metrics for zero-shot prompt
metrics_zero_shot = evaluate_predictions(
    df_evals.copy(),
    target_column="Cat3",
    predictions_column="gem1.0-zero-shot_predictions",
    postprocessing=True,
)
metrics_zero_shot

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'accuracy': 0.0,
 'weighted precision': 0.0,
 'weighted recall': 0.0,
 'macro f1': 0.0,
 'micro f1': 0.0}

In [82]:
# Test Few-Shot, and other prompts/possibilities
gem_pro_1_model_few = GenerativeModel(
    "gemini-1.0-pro-002",
    system_instruction=[system_prompt_few_shot],
    generation_config=generation_config,
    safety_settings=safety_settings,
)

In [83]:
predictions_few_shot = batch_predict(
    messages=messages_to_predict, model=gem_pro_1_model_few
)

100%|██████████| 50/50 [00:06<00:00,  7.28it/s]


In [84]:
df_evals["gem1.0-few-shot_predictions"] = predictions_few_shot
len(predictions_few_shot)

50

In [86]:
# Compute Evaluation Metrics for few-shot prompt
metrics_few_shot = evaluate_predictions(
    df_evals.copy(),
    target_column="Cat3",
    predictions_column="gem1.0-few-shot_predictions",
    postprocessing=True,
)
metrics_few_shot

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'accuracy': 0.2,
 'weighted precision': 0.22819047619047617,
 'weighted recall': 0.2,
 'macro f1': 0.12716262975778547,
 'micro f1': 0.2,
 'business_f1_score': 0.0}

In [119]:
tuning_gemini_df = prepare_tuning_dataset_from_df(
    tuning_df=train_df, system_prompt=system_prompt_zero_shot
)

KeyError: 'text'