## Fine Tuning a Custom Model

In [2]:
from google import genai
from google.genai import types

In [3]:
from dotenv import load_dotenv
import os

load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

client = genai.Client(api_key = GEMINI_API_KEY)

In [4]:
for model in client.models.list():
    if model.supported_actions and "createTunedModel" in model.supported_actions:
        print(model.name)


models/gemini-1.5-flash-001-tuning


In [5]:
from sklearn.datasets import fetch_20newsgroups

newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test")

newsgroups_train.target_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [6]:
print(newsgroups_train.data[0])

From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.umd.edu
Organization: University of Maryland, College Park
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----







Preparing the data

In [7]:
import email
import re

import pandas as pd


def preprocess_newsgroup_row(data):
    msg = email.message_from_string(data)
    text = f"{msg['Subject']}\n\n{msg.get_payload()}"
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    text = text[:40000]

    return text

def preprocess_newsgroup_data(newsgroup_dataset):
    df = pd.DataFrame({
        "Text": newsgroup_dataset.data,
        "Label": newsgroup_dataset.target
    })

    df["Text"] = df["Text"].apply(preprocess_newsgroup_row)
    df["Class Name"] = df['Label'].map(lambda l: newsgroup_dataset.target_names[l])

    return df


In [8]:
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)

df_train.head()

Unnamed: 0,Text,Label,Class Name
0,WHAT car is this!?\n\n I was wondering if anyo...,7,rec.autos
1,SI Clock Poll - Final Call\n\nA fair number of...,4,comp.sys.mac.hardware
2,"PB questions...\n\nwell folks, my mac plus fin...",4,comp.sys.mac.hardware
3,Re: Weitek P9000 ?\n\nRobert J.C. Kyanko () wr...,1,comp.graphics
4,Re: Shuttle Launch Question\n\nFrom article <>...,14,sci.space


In [9]:
def sample_data(df, num_samples, classes_to_keep):
    df = (df.groupby("Label")[df.columns].apply(lambda x: x.sample(num_samples)).reset_index(drop=True))
    df = df[df["Class Name"].str.contains(classes_to_keep)]
    df["Class Name"] = df["Class Name"].astype("category")

    return df

TRAIN_NUM_SAMPLES = 50
TEST_NUM_SAMPLES = 10
CLASSES_TO_KEEP = "^rec|^sci"

df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)

In [10]:
sample_idx = 0
sample_row = preprocess_newsgroup_row(newsgroups_test.data[sample_idx])
sample_label = newsgroups_test.target_names[newsgroups_test.target[sample_idx]]

print(sample_row)
print('---')
print('Label:', sample_label)

Need info on 88-89 Bonneville


 I am a little confused on all of the models of the 88-89 bonnevilles.
I have heard of the LE SE LSE SSE SSEI. Could someone tell me the
differences are far as features or performance. I am also curious to
know what the book value is for prefereably the 89 model. And how much
less than book value can you usually get them for. In other words how
much are they in demand this time of year. I have heard that the mid-spring
early summer is the best time to buy.

			Neil Gandler

---
Label: rec.autos


In [11]:
response = client.models.generate_content(
    model = "gemini-1.5-flash-001", 
    contents = sample_row
    )

print(response.text)

You're right, the 1988-1989 Bonneville lineup can be a bit confusing! Here's a breakdown of the trims and their differences, plus some info on values:

**1988-1989 Bonneville Trim Levels**

* **Base Bonneville:** This was the entry-level model, offering basic features like a 3.8L V6 engine, cloth upholstery, and minimal options. 
* **LE (Luxury Edition):** The LE added features like a more luxurious interior, power accessories, and a slightly more powerful engine (possibly the 3.8L Turbo V6). 
* **SE (Special Edition):**  The SE was a sportier trim with unique styling elements, usually a slightly firmer suspension, and possibly the Turbo V6.
* **LSE (Luxury Special Edition):**  The LSE combined the luxury features of the LE with the sportiness of the SE, offering a plush interior and a more powerful engine (likely the Turbo V6). 
* **SSE (Sport Sedan Edition):**  The SSE was the top-of-the-line performance model. It came standard with the Turbo V6, a sport-tuned suspension, and additio

In [12]:
# Ask the model directly in a zero-shot prompt.

prompt = "From what newsgroup does the following message originate?"
baseline_response = client.models.generate_content(
    model="gemini-1.5-flash-001",
    contents=[prompt, sample_row])
print(baseline_response.text)

While it's impossible to be 100% sure, this message most likely originates from a **Buick or Pontiac-related newsgroup**, possibly one of the following:

* **alt.autos.buick**
* **rec.autos.buick**
* **alt.autos.pontiac**
* **rec.autos.pontiac**

The message is clearly focused on a specific car model (Bonneville), and the detailed questions about trim levels (LE, SE, LSE, SSE, SSEI) and pricing suggest someone interested in buying a specific vehicle. This makes it highly probable that the message comes from a discussion group dedicated to those brands. 



In [13]:
from google.api_core import retry

# You can use a system instruction to do more direct prompting, and get a
# more succinct answer.

system_instruct = """
You are a classification service. You will be passed input that represents
a newsgroup post and you must respond with the newsgroup from which the post
originates.
"""

# Define a helper to retry when per-minute quota is reached.
is_retriable = lambda e: (isinstance(e, genai.errors.APIError) and e.code in {429, 503})

# If you want to evaluate your own technique, replace this body of this function
# with your model, prompt and other code and return the predicted answer.
@retry.Retry(predicate=is_retriable)
def predict_label(post: str) -> str:
    response = client.models.generate_content(
        model="gemini-1.5-flash-001",
        config=types.GenerateContentConfig(
            system_instruction=system_instruct),
        contents=post)

    rc = response.candidates[0]

    # Any errors, filters, recitation, etc we can mark as a general error
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        # Clean up the response.
        return response.text.strip()


prediction = predict_label(sample_row)

print(prediction)
print()
print("Correct!" if prediction == sample_label else "Incorrect.")

rec.autos.misc

Incorrect.


In [14]:
import tqdm
from tqdm.rich import tqdm as tqdmr
import warnings

# Enable tqdm features on Pandas.
tqdmr.pandas()

# But suppress the experimental warning
warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning)

# Further sample the test data to be mindful of the free-tier quota.
df_baseline_eval = sample_data(df_test, 2, '.*')

# Make predictions using the sampled data.
df_baseline_eval['Prediction'] = df_baseline_eval['Text'].progress_apply(predict_label)

# And calculate the accuracy.
accuracy = (df_baseline_eval["Class Name"] == df_baseline_eval["Prediction"]).sum() / len(df_baseline_eval)
print(f"Accuracy: {accuracy:.2%}")

Output()

Accuracy: 43.75%


In [15]:
df_baseline_eval

Unnamed: 0,Text,Label,Class Name,Prediction
0,Re: Questions about insurance companies (esp. ...,7,rec.autos,rec.autos
1,Re: V4 V6 V8 V12 Vx?\n\n (The Devil Reincarnat...,7,rec.autos,rec.autos
2,Re: A Point for Helmet Law is a Point for MC B...,8,rec.motorcycles,rec.motorcycles
3,Re: dogs\n\n\n>OOOOOOOpsssss. For a second the...,8,rec.motorcycles,alt.fan.dogs
4,Re: Football vs. BaseBall (was Game Length )\n...,9,rec.sport.baseball,rec.sport.baseball
5,Stats question\n\n\n\tI am just wondering whet...,9,rec.sport.baseball,rec.sports.baseball
6,NHL PLAYOFF RESULTS FOR GAMES PLAYED 4-22-93\n...,10,rec.sport.hockey,rec.sport.hockey
7,Re: Why Blues Pulverized Chicago WeenieHawks\n...,10,rec.sport.hockey,(error)
8,"Re: Overreacting (was Re: Once tapped, your co...",11,sci.crypt,talk.politics.guns
9,Re: Tempest\n\nIn article <>\n (Kirill Shklovs...,11,sci.crypt,(error)


Tune Custom Model

In [16]:
from collections.abc import Iterable 
import random

# Convert the data frame into a dataset suitable for tuning.
input_data = {'examples': 
    df_train[['Text', 'Class Name']]
      .rename(columns={'Text': 'textInput', 'Class Name': 'output'})
      .to_dict(orient='records')
 }

# If you are re-running this lab, add your model_id here.
model_id = None

# Or try and find a recent tuning job.
if not model_id:
  queued_model = None
  # Newest models first.
  for m in reversed(client.tunings.list()):
    # Only look at newsgroup classification models.
    if m.name.startswith('tunedModels/newsgroup-classification-model'):
      # If there is a completed model, use the first (newest) one.
      if m.state.name == 'JOB_STATE_SUCCEEDED':
        model_id = m.name
        print('Found existing tuned model to reuse.')
        break

      elif m.state.name == 'JOB_STATE_RUNNING' and not queued_model:
        # If there's a model still queued, remember the most recent one.
        queued_model = m.name
  else:
    if queued_model:
      model_id = queued_model
      print('Found queued model, still waiting.')


# Upload the training data and queue the tuning job.
if not model_id:
    tuning_op = client.tunings.tune(
        base_model="models/gemini-1.5-flash-001-tuning",
        training_dataset=input_data,
        config=types.CreateTuningJobConfig(
            tuned_model_display_name="Newsgroup classification model",
            batch_size=16,
            epoch_count=2,
        ),
    )

    print(tuning_op.state)
    model_id = tuning_op.name

print(model_id)

Found existing tuned model to reuse.
tunedModels/newsgroup-classification-model-mt0dnzoz9


In [17]:
import datetime
import time

MAX_WAIT = datetime.timedelta(minutes=10)

while not (tuned_model := client.tunings.get(name = model_id)).has_ended:

    print(tuned_model.state)
    time.sleep(60)

    if datetime.datetime.now(datetime.timezone.utc) - tuned_model.create_time > MAX_WAIT:
        print("Taking a shortcut, using a previously prepared model.")
        model_id = "tunedModels/newsgroup-classification-model-ltenbi1b"
        tuned_model = client.tunings.get(name=model_id)
        break

print(f"Done! The model state is: {tuned_model.state.name}")

if not tuned_model.has_succeeded and tuned_model.error:
    print("Error:", tuned_model.error)


Done! The model state is: JOB_STATE_SUCCEEDED


Use the new Model

In [18]:
new_text = """
First-timer looking to get out of here.

Hi, I'm writing about my interest in travelling to the outer limits!

What kind of craft can I buy? What is easiest to access from this 3rd rock?

Let me know how to do that please.
"""

response = client.models.generate_content(
    model=model_id, contents=new_text)

print(response.text)

sci.space


In [19]:
@retry.Retry(predicate=is_retriable)
def classify_text(text: str) -> str:
    """Classify the provided text into a known newsgroup."""
    response = client.models.generate_content(
        model=model_id, contents=text)
    rc = response.candidates[0]

    # Any errors, filters, recitation, etc we can mark as a general error
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        return rc.content.parts[0].text


# The sampling here is just to minimise your quota usage. If you can, you should
# evaluate the whole test set with `df_model_eval = df_test.copy()`.
df_model_eval = sample_data(df_test, 4, '.*')

df_model_eval["Prediction"] = df_model_eval["Text"].progress_apply(classify_text)

accuracy = (df_model_eval["Class Name"] == df_model_eval["Prediction"]).sum() / len(df_model_eval)
print(f"Accuracy: {accuracy:.2%}")

Output()

Accuracy: 87.50%


Token Usage

In [20]:
# Calculate the *input* cost of the baseline model with system instructions.
sysint_tokens = client.models.count_tokens(
    model='gemini-1.5-flash-001', contents=[system_instruct, sample_row]
).total_tokens
print(f'System instructed baseline model: {sysint_tokens} (input)')

# Calculate the input cost of the tuned model.
tuned_tokens = client.models.count_tokens(model=tuned_model.base_model, contents=sample_row).total_tokens
print(f'Tuned model: {tuned_tokens} (input)')

savings = (sysint_tokens - tuned_tokens) / tuned_tokens
print(f'Token savings: {savings:.2%}')  # Note that this is only n=1.

System instructed baseline model: 172 (input)
Tuned model: 136 (input)
Token savings: 26.47%


In [21]:
baseline_token_output = baseline_response.usage_metadata.candidates_token_count
print('Baseline (verbose) output tokens:', baseline_token_output)

tuned_model_output = client.models.generate_content(
    model=model_id, contents=sample_row)
tuned_tokens_output = tuned_model_output.usage_metadata.candidates_token_count
print('Tuned output tokens:', tuned_tokens_output)

Baseline (verbose) output tokens: 141
Tuned output tokens: 4
