# Fine-tuning a custom model

- This walkthrough fine-tunes a custom, task-specific model. Fine-tuning can be used for a variety of tasks, from classic NLP problems like entity extraction or summarisation to creative tasks like stylised generation.

- Here, you will fine-tune a model to classify a piece of text (a newsgroup post) into the category it belongs to (the newsgroup name).

In [1]:
import google.generativeai as genai
from IPython.display import HTML, Markdown, display
import os

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

**Explore available models**

- Use the `TunedModel.create` API method to start the fine-tuning job and create your custom model.
- Find a model that supports it through the `models.list` endpoint.

In [2]:
for model in genai.list_models():
    if "createTunedModel" in model.supported_generation_methods:
        print(model.name)

models/gemini-1.0-pro-001
models/gemini-1.5-flash-001-tuning


## Download the Dataset

In [3]:
from sklearn.datasets import fetch_20newsgroups

newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test")

# View list of class names for dataset
newsgroups_train.target_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [4]:
print(newsgroups_train.data[0])

From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.umd.edu
Organization: University of Maryland, College Park
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----







## Prepare the dataset

- The preprocessing below removes personal information, which a model could otherwise use as a "shortcut" to known users of a forum, and formats the text to look more like regular text and less like a newsgroup post (e.g. by removing the mail headers).
- This normalisation helps the model generalise to regular text rather than over-depend on specific fields. If your input data will always be newsgroup posts, it may be helpful to leave this structure in place when it provides genuine signal.

In [5]:
import email
import re

import pandas as pd


def preprocess_newsgroup_row(data):
    # Extract only the subject and body
    msg = email.message_from_string(data)
    text = f"{msg['Subject']}\n\n{msg.get_payload()}"
    # Strip any remaining email addresses
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    # Truncate the text to fit within the input limits
    text = text[:40000]

    return text


def preprocess_newsgroup_data(newsgroup_dataset):
    # Put data points into dataframe
    df = pd.DataFrame(
        {"Text": newsgroup_dataset.data, "Label": newsgroup_dataset.target}
    )
    # Clean up the text
    df["Text"] = df["Text"].apply(preprocess_newsgroup_row)
    # Match label to target name index
    df["Class Name"] = df["Label"].map(lambda l: newsgroup_dataset.target_names[l])

    return df
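
As a quick sanity check of the cleaning steps, the sketch below applies the same three operations (parse the message, keep subject and body, strip email addresses) to a single made-up post. The post content and the `example.com` address are invented for illustration and do not come from the dataset.

```python
import email
import re

# A made-up post in the same shape as the dataset entries:
# mail headers, a blank line, then the body.
raw_post = (
    "From: someone@example.com (A. Poster)\n"
    "Subject: Test subject\n"
    "Organization: Example Org\n"
    "\n"
    "Please reply to someone@example.com if you know the answer.\n"
)

msg = email.message_from_string(raw_post)
# Keep only the subject and body; the other headers are dropped.
text = f"{msg['Subject']}\n\n{msg.get_payload()}"
# Remove any email addresses that remain in the body.
text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)

print(text)
```

The `From` and `Organization` headers disappear and the address inside the body is blanked out, leaving only the subject line and the message text.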

In [6]:
# Apply preprocessing to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)

df_train.head()

| | Text | Label | Class Name |
|---|---|---|---|
| 0 | WHAT car is this!?\n\n I was wondering if anyo... | 7 | rec.autos |
| 1 | SI Clock Poll - Final Call\n\nA fair number of... | 4 | comp.sys.mac.hardware |
| 2 | PB questions...\n\nwell folks, my mac plus fin... | 4 | comp.sys.mac.hardware |
| 3 | Re: Weitek P9000 ?\n\nRobert J.C. Kyanko () wr... | 1 | comp.graphics |
| 4 | Re: Shuttle Launch Question\n\nFrom article <>... | 14 | sci.space |


Now sample the data. You will keep 50 rows for each category for training.

A small sample like this is workable because the technique used here (parameter-efficient fine-tuning, or PEFT) updates a relatively small number of parameters and does not require training a new model or updating the large model.

In [7]:
def sample_data(df, num_samples, classes_to_keep):
    # Sample rows, selecting num_samples of each Label.
    df = (
        df.groupby("Label")[df.columns]
        .apply(lambda x: x.sample(num_samples))
        .reset_index(drop=True)
    )

    # Copy the filtered rows so the column assignment below does not act on a view.
    df = df[df["Class Name"].str.contains(classes_to_keep)].copy()
    df["Class Name"] = df["Class Name"].astype("category")

    return df


TRAIN_NUM_SAMPLES = 50
TEST_NUM_SAMPLES = 10
# Keep rec.* and sci.*
CLASSES_TO_KEEP = "^rec|^sci"

df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)
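
To see which classes the `"^rec|^sci"` filter keeps, the sketch below runs the same pattern over the 20 class names with the standard `re` module. It assumes `str.contains` is doing a regular-expression search from the start of each name, which is how the anchored pattern is used here; it is a plain-Python illustration, not a replacement for the pandas code.

```python
import re

# The 20 class names listed earlier in the dataset output.
target_names = [
    "alt.atheism", "comp.graphics", "comp.os.ms-windows.misc",
    "comp.sys.ibm.pc.hardware", "comp.sys.mac.hardware", "comp.windows.x",
    "misc.forsale", "rec.autos", "rec.motorcycles", "rec.sport.baseball",
    "rec.sport.hockey", "sci.crypt", "sci.electronics", "sci.med",
    "sci.space", "soc.religion.christian", "talk.politics.guns",
    "talk.politics.mideast", "talk.politics.misc", "talk.religion.misc",
]

pattern = "^rec|^sci"  # same value as CLASSES_TO_KEEP

# Keep a name if the pattern matches at the start of the string.
kept = [name for name in target_names if re.search(pattern, name)]

print(len(kept))
print(kept)
```

Eight classes survive the filter, so with 50 training rows per class the tuning set holds 400 rows.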

## Evaluate baseline performance

Before you start tuning a model, it's good practice to perform an evaluation on the available models to ensure you can measure how much the tuning helps.

First identify a single sample row to use for visual inspection.


In [8]:
sample_idx = 0
sample_row = preprocess_newsgroup_row(newsgroups_test.data[sample_idx])
sample_label = newsgroups_test.target_names[newsgroups_test.target[sample_idx]]

print(sample_row)
print('---')
print('Label:', sample_label)

Need info on 88-89 Bonneville


 I am a little confused on all of the models of the 88-89 bonnevilles.
I have heard of the LE SE LSE SSE SSEI. Could someone tell me the
differences are far as features or performance. I am also curious to
know what the book value is for prefereably the 89 model. And how much
less than book value can you usually get them for. In other words how
much are they in demand this time of year. I have heard that the mid-spring
early summer is the best time to buy.

			Neil Gandler

---
Label: rec.autos


Passing the text directly in as a prompt does not yield the desired results. The model will attempt to respond to the message.

In [9]:
baseline_model = genai.GenerativeModel("gemini-1.5-flash-001")
response = baseline_model.generate_content(sample_row)
print(response.text)

You're right, the Bonneville model lineup for 1988-89 was a bit confusing! Here's a breakdown to clear things up:

**1988-1989 Bonneville Models:**

* **Base Bonneville:** The standard model, typically with a 3.8L V6 engine. It offered basic features and was the most affordable option.
* **LE (Luxury Edition):**  This trim added more comfort features like plusher upholstery, power accessories, and often a more upscale interior design.
* **SE (Special Edition):**  This trim focused on aesthetics with a sportier look, including unique exterior and interior styling elements, and sometimes a slightly more powerful engine option. 
* **LSE (Luxury Special Edition):** This trim combined features from both the LE and SE, providing a luxurious and sporty experience.  
* **SSE (Sport Sedan Edition):** This model was the performance-focused option. It came with a larger 3.8L V6 engine with more horsepower and torque, a sportier suspension, and aggressive styling elements. 
* **SSEi (Sport Sedan E

You can use the prompt engineering techniques you have learned this week to induce the model to perform the desired task.

Try some of your own ideas and see what is effective, or check out the following cells for different approaches. Note that they have different levels of effectiveness!

In [11]:
# Ask the model directly in a zero-shot prompt.

prompt = "From what newsgroup does the following message originate?"
baseline_response = baseline_model.generate_content([prompt, sample_row])
print(baseline_response.text)

This message likely originates from a newsgroup focused on **Buick vehicles**, specifically the **Bonneville model**. 

Here's why:

* **Topic:** The message explicitly mentions "88-89 Bonnevilles" and inquires about different trim levels (LE, SE, LSE, SSE, SSEi). This indicates a strong focus on the Buick Bonneville.
* **Content:** The message asks for information about features, performance, and value, typical questions asked in car-related forums.
* **Common Newsgroup:**  A newsgroup like **alt.autos.buick** or **rec.autos.buick** would be a likely source, given its dedicated focus on Buick vehicles. 

However, without seeing the actual newsgroup header, it's impossible to know for sure. 



That technique produces quite a verbose response. You could try to pick out the relevant text, or refine the prompt even further.
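
One way to pick out the relevant text is a simple pattern match for dotted, lowercase names. The sketch below is a heuristic under that assumption; the pattern and the `NEWSGROUP_PATTERN` name are illustrative choices, and the input is one sentence copied from the verbose response above.

```python
import re

# One sentence taken from the verbose zero-shot response.
verbose_response = (
    "A newsgroup like **alt.autos.buick** or **rec.autos.buick** would be "
    "a likely source, given its dedicated focus on Buick vehicles."
)

# Heuristic: a lowercase word followed by one or more ".segment" parts,
# e.g. "rec.autos". This can also match other dotted lowercase tokens.
NEWSGROUP_PATTERN = re.compile(r"\b[a-z]+(?:\.[a-z0-9-]+)+")

candidates = NEWSGROUP_PATTERN.findall(verbose_response)
print(candidates)
```

Extraction like this recovers candidate names, but it cannot make them correct: neither candidate here is one of the dataset's classes, which is why the next step refines the prompt instead.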

In [12]:
from google.api_core import retry

# You can use a system instruction to do more direct prompting, and get a
# more succinct answer.

system_instruct = """
You are a classification service. You will be passed input that represents
a newsgroup post and you must respond with the newsgroup from which the post
originates.
"""

instructed_model = genai.GenerativeModel("gemini-1.5-flash-001",
                                         system_instruction=system_instruct)

retry_policy = {"retry": retry.Retry(predicate=retry.if_transient_error)}

# If you want to evaluate your own technique, replace this function with your
# model, prompt and other code and return the predicted answer.
def predict_label(post: str) -> str:
    response = instructed_model.generate_content(post, request_options=retry_policy)
    rc = response.candidates[0]

    # Any errors, filters, recitation, etc we can mark as a general error
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        # Clean up the response.
        return response.text.strip()


prediction = predict_label(sample_row)

print(prediction)
print()
print("Correct!" if prediction == sample_label else "Incorrect.")

rec.autos.misc

Incorrect.


Now run a short evaluation using the function defined above. The test set is further sampled to ensure the experiment runs smoothly on the API's free tier. In practice you would evaluate over the whole set.

In [13]:
from tqdm.rich import tqdm

tqdm.pandas()


# Further sample the test data to be mindful of the free-tier quota.
df_baseline_eval = sample_data(df_test, 2, '.*')

# Make predictions using the sampled data.
df_baseline_eval['Prediction'] = df_baseline_eval['Text'].progress_apply(predict_label)

# And calculate the accuracy.
accuracy = (df_baseline_eval["Class Name"] == df_baseline_eval["Prediction"]).sum() / len(df_baseline_eval)
print(f"Accuracy: {accuracy:.2%}")


Accuracy: 25.00%


Now take a look at the dataframe to compare the predictions with the labels.

In [14]:
df_baseline_eval

| | Text | Label | Class Name | Prediction |
|---|---|---|---|---|
| 0 | Re: V4 V6 V8 V12 Vx?\n\nIn article <>\n (The D... | 7 | rec.autos | rec.autos.misc |
| 1 | Re: WARNING.....(please read)...\n\nIn article... | 7 | rec.autos | alt.politics.abortion |
| 2 | Back Breaker, Near Hit!!\n\nAbout a year and h... | 8 | rec.motorcycles | rec.motorcycles |
| 3 | Re: Good Reasons to Wave at each other\n\n (Mi... | 8 | rec.motorcycles | rec.motorcycles |
| 4 | Kevin Mitchell Does It Again\n\nIn what seems ... | 9 | rec.sport.baseball | rec.sport.baseball |
| 5 | Re: Game Length (was Re: Braves Update!!\n\nIn... | 9 | rec.sport.baseball | rec.sport.baseball |
| 6 | Re: And you think ESPN shafted you?\n\nIn arti... | 10 | rec.sport.hockey | rec.sports.hockey |
| 7 | Re: Minnesota Shame?\n\n>I've been under the i... | 10 | rec.sport.hockey | rec.sports.hockey |
| 8 | Re: The [secret] source of that announcement\n... | 11 | sci.crypt | (error) |
| 9 | Re: Clipper and conference calls\n\n\n\|> > ... | 11 | sci.crypt | comp.security.unix |
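
It can help to separate predictions that are wrong but valid class names from predictions that are not class names at all. The sketch below does this for the ten rows displayed above only (not the full evaluation set); the variable names are illustrative.

```python
# The eight classes kept by the rec.* / sci.* filter.
known_labels = {
    "rec.autos", "rec.motorcycles", "rec.sport.baseball", "rec.sport.hockey",
    "sci.crypt", "sci.electronics", "sci.med", "sci.space",
}

# (true label, prediction) pairs copied from the displayed rows.
results = [
    ("rec.autos", "rec.autos.misc"),
    ("rec.autos", "alt.politics.abortion"),
    ("rec.motorcycles", "rec.motorcycles"),
    ("rec.motorcycles", "rec.motorcycles"),
    ("rec.sport.baseball", "rec.sport.baseball"),
    ("rec.sport.baseball", "rec.sport.baseball"),
    ("rec.sport.hockey", "rec.sports.hockey"),
    ("rec.sport.hockey", "rec.sports.hockey"),
    ("sci.crypt", "(error)"),
    ("sci.crypt", "comp.security.unix"),
]

# Exact matches between label and prediction.
correct = [pred for label, pred in results if label == pred]
# Predictions that are not any of the known class names.
unknown = [pred for label, pred in results if pred not in known_labels]
# Predictions that are valid class names but the wrong one.
wrong_known = [pred for label, pred in results if label != pred and pred in known_labels]

print(len(correct), len(unknown), len(wrong_known))
```

In these rows every miss is an output outside the label set (near-misses such as `rec.sports.hockey`, invented groups, or an error), rather than a confusion between two real classes.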


## Tune a custom model

In this example you'll use tuning to help create a model that requires no prompting or system instructions and outputs succinct text from the classes you provide in the training data.

The data contains both input text (the processed posts) and output text (the category, or newsgroup), which you can use to start tuning a model.

The Python SDK for tuning supports Pandas dataframes as input, so you don't need any custom data generators or pipelines. Just specify the input and the relevant columns as the `input_key` and `output_key`.

When calling `create_tuned_model`, you can specify model tuning hyperparameters too:

- `epoch_count`: defines how many times to loop through the data,
- `batch_size`: defines how many rows to process in a single step, and
- `learning_rate`: defines the scaling factor for updating model weights at each step.
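
As a rough worked example of how `epoch_count` and `batch_size` interact, the sketch below estimates the number of training steps for the values used in this walkthrough (8 classes, 50 rows each, batch size 16, 2 epochs). It assumes one step per batch and that a final partial batch would count as a step; the tuning service's actual step accounting may differ.

```python
import math

num_classes = 8      # the rec.* and sci.* classes kept by the filter
rows_per_class = 50  # TRAIN_NUM_SAMPLES
batch_size = 16
epoch_count = 2

num_rows = num_classes * rows_per_class
# Assumption: one training step per batch, rounding up for a partial batch.
steps_per_epoch = math.ceil(num_rows / batch_size)
total_steps = steps_per_epoch * epoch_count

print(num_rows, steps_per_epoch, total_steps)  # 400 25 50
```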

In [15]:
import random


# Append a random number to the model ID so you can re-run with a higher chance
# of creating a unique model ID.
model_id = f"newsgroup-classifier-{random.randint(10000, 99999)}"

# Upload the training data and queue the tuning job.
tuning_op = genai.create_tuned_model(
    "models/gemini-1.5-flash-001-tuning",
    training_data=df_train,
    input_key="Text",  # the column to use as input
    output_key="Class Name",  # the column to use as output
    id=model_id,
    display_name="Newsgroup classification model",
    batch_size=16,
    epoch_count=2,
)

print(model_id)

newsgroup-classifier-11542


This has created a tuning job that will run in the background. To inspect the progress of the tuning job, run this cell to plot the current status and loss curve. Once the status reaches `ACTIVE`, tuning is complete and the model is ready to use.

Tuning jobs are queued, so it may look like no training steps have been taken initially, but the job will progress. Tuning can take upwards of 20 minutes, depending on factors like your dataset size and how busy the tuning infrastructure is.

It is safe to stop this cell at any point. It will not stop the tuning job.
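
If you would rather not wait indefinitely, a wait loop of this kind can be given a deadline. The sketch below is a generic, hypothetical helper (`wait_for_state` and its parameters are not part of the SDK); it takes any callable that returns the current state name and is exercised here with a fake sequence of states instead of a real tuning job.

```python
import time
from typing import Callable


def wait_for_state(
    get_state: Callable[[], str],
    target: str = "ACTIVE",
    poll_seconds: float = 60.0,
    timeout_seconds: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll get_state() until it returns target; give up after the timeout."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        if get_state() == target:
            return True
        if time.monotonic() >= deadline:
            return False
        sleep(poll_seconds)


# Demo with a fake job that becomes ACTIVE on the third poll.
fake_states = iter(["CREATING", "CREATING", "ACTIVE"])
done = wait_for_state(lambda: next(fake_states), poll_seconds=0, sleep=lambda s: None)
print(done)
```

With a real job, the callable could wrap the same lookup used in the cell below, for example `lambda: genai.get_tuned_model(f"tunedModels/{model_id}").state.name`.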

In [None]:
import time
import seaborn as sns


while (tuned_model := genai.get_tuned_model(f"tunedModels/{model_id}")).state.name != 'ACTIVE':

    print(tuned_model.state)
    time.sleep(60)

print(f"Done! The model is {tuned_model.state.name}")

# Plot the loss curve.
snapshots = pd.DataFrame(tuned_model.tuning_task.snapshots)
sns.lineplot(data=snapshots, x="step", y="mean_loss")

1
