# Classifying embeddings with Keras and the Gemini API

## Overview

In this notebook, we'll use the embeddings produced by the Gemini API to train a model that classifies newsgroup posts into their categories (the newsgroups themselves) based on the post contents.

This technique uses the Gemini API's embeddings as input, avoiding the need to train on text input directly, and as a result it is able to perform quite well using relatively few examples compared to training a text model from scratch.

In [23]:
!pip uninstall -qqy jupyterlab kfp 2>/dev/null  # Remove unused conflicting packages
!pip install -U -q "google-genai==1.7.0"

In [24]:
from google import genai
from google.genai import types

genai.__version__

'1.7.0'

### Set up your API key

In [38]:
from kaggle_secrets import UserSecretsClient

GOOGLE_API_KEY = UserSecretsClient().get_secret("GOOGLE_API_KEY")

client = genai.Client(api_key=GOOGLE_API_KEY)

## Dataset

We will use the 20 Newsgroups dataset from scikit-learn.

In [39]:
from sklearn.datasets import fetch_20newsgroups

newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test")

In [40]:
#To see all the attributes
print(dir(newsgroups_train))

['DESCR', 'data', 'filenames', 'target', 'target_names']


In [41]:
# View list of class names for dataset
newsgroups_train.target_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [42]:
#Sample Training Data
#print(newsgroups_train.data[0])

In [44]:
# Generally you should preprocess the data to remove sensitive information
# such as email addresses, passwords and names. The row-level cleanup below
# is left commented out as an optional step.

# import email
# import re

import pandas as pd


# def preprocess_newsgroup_row(data):
#     # Extract only the subject and body
#     msg = email.message_from_string(data)
#     text = f"{msg['Subject']}\n\n{msg.get_payload()}"
#     # Strip any remaining email addresses
#     text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
#     # Truncate each entry to 5,000 characters
#     text = text[:5000]

#     return text


def preprocess_newsgroup_data(newsgroup_dataset):
    # Put data points into dataframe
    df = pd.DataFrame(
        {"Text": newsgroup_dataset.data, "Label": newsgroup_dataset.target}
    )
    # Text cleanup is disabled here; uncomment the next line to apply it.
    # df["Text"] = df["Text"].apply(preprocess_newsgroup_row)
    # Match label to target name index
    df["Class Name"] = df["Label"].map(lambda l: newsgroup_dataset.target_names[l])

    return df

#Apply preprocessing function to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)

df_train.head()

Unnamed: 0,Text,Label,Class Name
0,From: lerxst@wam.umd.edu (where's my thing)\nS...,7,rec.autos
1,From: guykuo@carson.u.washington.edu (Guy Kuo)...,4,comp.sys.mac.hardware
2,From: twillis@ec.ecn.purdue.edu (Thomas E Will...,4,comp.sys.mac.hardware
3,From: jgreen@amber (Joe Green)\nSubject: Re: W...,1,comp.graphics
4,From: jcm@head-cfa.harvard.edu (Jonathan McDow...,14,sci.space


Next, we will sample the data by taking 100 data points per category from the training dataset and dropping some of the categories, which makes the process faster. Here, we will keep the science categories to compare.

In [50]:
"""
The .cat.codes attribute in pandas returns the integer codes (numerical representations) for categorical data, 
where each unique category is assigned a consecutive integer starting from 0.
"""

def sample_data(df, num_samples, classes_to_keep):
    # Sample rows, selecting num_samples of each Label.
    df = (
        df.groupby("Label")[df.columns]
        .apply(lambda x: x.sample(num_samples))
        .reset_index(drop=True)
    )

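    # Copy the filtered rows so the column assignments below act on an
    # independent frame rather than a slice of the sampled data.
    df = df[df["Class Name"].str.contains(classes_to_keep)].copy()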

    # We have fewer categories now, so re-calibrate the label encoding.
    df["Class Name"] = df["Class Name"].astype("category")
    df["Encoded Label"] = df["Class Name"].cat.codes

    return df
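
To illustrate the re-encoding step, here is a small standalone sketch of how `.cat.codes` behaves. The series contents are made up for illustration; only the pandas calls match the function above.

```python
import pandas as pd

# Illustrative class names, not taken from the sampled dataset.
names = pd.Series(["sci.space", "sci.crypt", "sci.med", "sci.space"])

# Converting to a categorical dtype sorts the unique string values and
# assigns each one a consecutive integer code starting from 0.
as_category = names.astype("category")

categories = list(as_category.cat.categories)
codes = as_category.cat.codes.tolist()

print(categories)  # ['sci.crypt', 'sci.med', 'sci.space']
print(codes)       # [2, 0, 1, 2]
```

This is why the original `Label` values (for example 11 for `sci.crypt`) become a compact range in `Encoded Label` once the other categories are dropped.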

In [51]:
TRAIN_NUM_SAMPLES = 100
TEST_NUM_SAMPLES = 25
# Class name should contain 'sci' to keep science categories. sci.med, sci.space etc.
# Try different labels from the data - see newsgroups_train.target_names
CLASSES_TO_KEEP = "sci"

df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)

In [52]:
df_train.value_counts("Class Name")

Class Name
sci.crypt          100
sci.electronics    100
sci.med            100
sci.space          100
Name: count, dtype: int64

In [53]:
df_test.value_counts("Class Name")

Class Name
sci.crypt          25
sci.electronics    25
sci.med            25
sci.space          25
Name: count, dtype: int64

## Create the embeddings

In this section, you will generate embeddings for each piece of text using the Gemini API embeddings endpoint.

### Task types

The `text-embedding-004` model supports a task type parameter that generates embeddings tailored for the specific task.

Task Type | Description
--- | ---
RETRIEVAL_QUERY | Specifies the given text is a query in a search/retrieval setting.
RETRIEVAL_DOCUMENT | Specifies the given text is a document in a search/retrieval setting.
SEMANTIC_SIMILARITY | Specifies the given text will be used for Semantic Textual Similarity (STS).
CLASSIFICATION | Specifies that the embeddings will be used for classification.
CLUSTERING | Specifies that the embeddings will be used for clustering.
FACT_VERIFICATION | Specifies that the given text will be used for fact verification.

For this example we will be performing classification.

In [55]:
from google.api_core import retry
import tqdm #tqdm adds progress bars.
from tqdm.rich import tqdm as tqdmr
import warnings

# Add tqdm to Pandas...
tqdmr.pandas()

# Define a helper to retry when per-minute quota is reached.
is_retriable = lambda e: (isinstance(e, genai.errors.APIError) and e.code in {429, 503})

# The @ syntax applies a decorator. Here, Retry wraps embed_fn so that calls
# failing with a retriable error (429 or 503) are retried for up to 300 seconds.
@retry.Retry(predicate=is_retriable, timeout=300.0)
def embed_fn(text: str) -> list[float]:
    # Task type is set to classification
    response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=text,
        config=types.EmbedContentConfig(
            task_type="classification",
        ),
    )

    return response.embeddings[0].values


def create_embeddings(df):
    df["Embeddings"] = df["Text"].progress_apply(embed_fn) # Progress apply comes from tqdm.
    return df

This code is optimised for clarity, and is not particularly fast. 

TO DO: Implement [batch](https://ai.google.dev/api/embeddings#method:-models.batchembedcontents) or parallel/asynchronous embedding generation.
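
One possible direction for the batching idea is sketched below. It is not part of the original notebook: `chunked` and `embed_in_batches` are hypothetical helper names, the batch size of 100 is an assumed limit that should be checked against the API documentation, and the commented-out wiring assumes that `embed_content` accepts a list of texts and returns one embedding per item.

```python
from typing import Callable, Iterator, Sequence


def chunked(items: Sequence[str], size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of `items` with at most `size` elements each."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def embed_in_batches(
    texts: Sequence[str],
    embed_batch: Callable[[Sequence[str]], list],
    batch_size: int = 100,  # Assumed per-request limit; verify before use.
) -> list:
    """Embed `texts` one batch at a time, preserving the input order."""
    vectors = []
    for batch in chunked(texts, batch_size):
        vectors.extend(embed_batch(batch))
    return vectors


# Hypothetical wiring to the client defined earlier (not executed here):
#
# def embed_batch(batch):
#     response = client.models.embed_content(
#         model="models/text-embedding-004",
#         contents=list(batch),
#         config=types.EmbedContentConfig(task_type="classification"),
#     )
#     return [e.values for e in response.embeddings]
#
# df_train["Embeddings"] = embed_in_batches(df_train["Text"].tolist(), embed_batch)
```

Passing the embedding call in as a function keeps the batching logic independent of the API client, so it can be exercised with a stand-in before spending quota.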

In [56]:
df_train = create_embeddings(df_train)
df_test = create_embeddings(df_test)


In [57]:
df_train.head()

Unnamed: 0,Text,Label,Class Name,Encoded Label,Embeddings
0,From: mjr@tis.com (Marcus J Ranum)\nSubject: R...,11,sci.crypt,0,"[-0.008053452, 0.026679872, -0.03734975, 0.020..."
1,From: strnlght@netcom.com (David Sternlight)\n...,11,sci.crypt,0,"[0.00313395, 0.018756181, -0.036363836, 0.0334..."
2,From: pgf@srl03.cacs.usl.edu (Phil G. Fraering...,11,sci.crypt,0,"[-0.03200049, 0.018352652, -0.03407377, 0.0221..."
3,"From: pla@sktb.demon.co.uk (""Paul L. Allen"")\n...",11,sci.crypt,0,"[0.0078068087, 0.012841915, -0.056115758, 0.01..."
4,"Subject: Re: Don't fight Clipper Chip, subvert...",11,sci.crypt,0,"[0.0017008006, 0.034731373, -0.048679285, 0.00..."


## Build a classification model

Here you will define a simple model that accepts the raw embedding data as input, has two hidden layers, and has an output layer specifying the class probabilities. Each output value corresponds to the probability that a piece of text belongs to a particular newsgroup class.

In [58]:
import keras
from keras import layers


def build_classification_model(input_size: int, num_classes: int) -> keras.Model:
    return keras.Sequential(
        [
            layers.Input([input_size], name="embedding_inputs"),
            layers.Dense(input_size, activation="relu", name="hidden_1"), # First hidden layer
            layers.Dense(input_size // 2, activation="relu", name="hidden_2"), # Second hidden layer
            layers.Dense(num_classes, activation="softmax", name="output_probs"),
        ]
    )

In [59]:
# Derive the embedding size from observing the data. The embedding size can also be specified
# with the `output_dimensionality` parameter to `embed_content` if you need to reduce it.
embedding_size = len(df_train["Embeddings"].iloc[0])

classifier = build_classification_model(
    embedding_size, len(df_train["Class Name"].unique())
)
classifier.summary()

classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    metrics=["accuracy"],
)
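
The parameter counts reported by `classifier.summary()` can be checked by hand. The sketch below assumes a 768-dimensional embedding (an assumption; the notebook derives the size from the data) and the 4 science classes kept earlier. The helper names are hypothetical.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def classifier_params(input_size: int, num_classes: int) -> int:
    """Total trainable parameters for the three Dense layers defined above."""
    hidden_1 = dense_params(input_size, input_size)
    hidden_2 = dense_params(input_size, input_size // 2)
    output = dense_params(input_size // 2, num_classes)
    return hidden_1 + hidden_2 + output


# Assumed sizes: 768-dimensional embeddings and 4 classes.
total = classifier_params(768, 4)
print(total)  # 887428
```

Most of the parameters sit in the first hidden layer, which is as wide as the embedding itself.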

## Train the model

Finally, you can train your model. The training cell uses early stopping to exit the training loop once the training accuracy stops improving, so the number of epochs actually run may be lower than the specified value.
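
To make the stopping rule concrete, here is a simplified sketch of patience-based early stopping. It is not Keras's implementation, the function name `epochs_run` is hypothetical, and the accuracy values are made up for illustration.

```python
def epochs_run(metric_per_epoch, patience):
    """Return how many epochs run before patience-based stopping triggers.

    Training stops after `patience` consecutive epochs in which the
    monitored metric fails to exceed its best value so far.
    """
    best = float("-inf")
    wait = 0
    for epoch, value in enumerate(metric_per_epoch, start=1):
        if value > best:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(metric_per_epoch)


# Illustrative accuracy curve: the best value is reached at epoch 5.
accuracies = [0.30, 0.70, 0.76, 0.90, 0.95, 0.95, 0.94, 0.95, 0.99]
stopped_at = epochs_run(accuracies, patience=3)
print(stopped_at)  # 8
```

In this made-up curve the final value of 0.99 is never reached, because three epochs without improvement end the run first.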

In [60]:
import numpy as np


NUM_EPOCHS = 20
BATCH_SIZE = 32

# Split the x and y components of the train and validation subsets.
y_train = df_train["Encoded Label"]
x_train = np.stack(df_train["Embeddings"])
y_val = df_test["Encoded Label"]
x_val = np.stack(df_test["Embeddings"])

# Specify that it's OK to stop early if accuracy stabilises.
early_stop = keras.callbacks.EarlyStopping(monitor="accuracy", patience=3)

# Train the model for the desired number of epochs.
history = classifier.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_val, y_val),
    callbacks=[early_stop],
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
)

Epoch 1/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 26ms/step - accuracy: 0.3202 - loss: 1.3668 - val_accuracy: 0.5000 - val_loss: 1.2680
Epoch 2/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.6977 - loss: 1.1740 - val_accuracy: 0.6900 - val_loss: 1.0239
Epoch 3/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.7626 - loss: 0.8729 - val_accuracy: 0.8600 - val_loss: 0.6930
Epoch 4/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9091 - loss: 0.5095 - val_accuracy: 0.9100 - val_loss: 0.4353
Epoch 5/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9550 - loss: 0.2620 - val_accuracy: 0.9000 - val_loss: 0.3218
Epoch 6/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9623 - loss: 0.1788 - val_accuracy: 0.9000 - val_loss: 0.2910
Epoch 7/20
13/13 ━━━━

## Evaluate model performance

Use Keras [`Model.evaluate`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#evaluate) to calculate the loss and accuracy on the test dataset.

In [61]:
classifier.evaluate(x=x_val, y=y_val, return_dict=True)

4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - accuracy: 0.9281 - loss: 0.2574


{'accuracy': 0.9399999976158142, 'loss': 0.21013770997524261}
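
For intuition about the two reported numbers, the sketch below computes sparse categorical cross-entropy and accuracy with NumPy on a tiny made-up batch. It is a simplified illustration: Keras also clips probabilities for numerical stability, which is omitted here, and the function names are hypothetical.

```python
import numpy as np


def sparse_categorical_crossentropy(y_true, probs):
    """Mean negative log-probability assigned to the true class."""
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true, dtype=int)
    # Pick, for each row, the probability predicted for its true label.
    picked = probs[np.arange(len(y_true)), y_true]
    return float(-np.mean(np.log(picked)))


def accuracy(y_true, probs):
    """Fraction of rows whose most probable class equals the true label."""
    predicted = np.argmax(np.asarray(probs), axis=1)
    return float(np.mean(predicted == np.asarray(y_true)))


# Made-up predictions for three examples over two classes.
example_probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
example_labels = [0, 1, 1]

loss_value = sparse_categorical_crossentropy(example_labels, example_probs)
acc_value = accuracy(example_labels, example_probs)
print(round(loss_value, 4))  # 0.4149
print(round(acc_value, 4))   # 0.6667
```

The labels are integers rather than one-hot vectors, which is why the "sparse" variant of the loss is used with `Encoded Label`.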

To learn more about training models with Keras, including how to visualise the model training metrics, read [Training & evaluation with built-in methods](https://www.tensorflow.org/guide/keras/training_with_built_in_methods).

## Try a custom prediction

Now that you have a trained model with good evaluation metrics, you can try to make a prediction with new, hand-written data. Use the provided example or try your own data to see how the model performs.

In [62]:
def make_prediction(text: str) -> np.ndarray:
    """Infer category probabilities from the provided text."""
    # The model takes embeddings as input, so calculate them first.
    embedded = embed_fn(text)

    # The input must be batched, so wrap the embedding in a list to
    # provide a batch of 1.
    inp = np.array([embedded])

    # Un-batch the single result.
    [result] = classifier.predict(inp)
    return result

In [63]:
# This example avoids any space-specific terminology to see if the model avoids
# biases towards specific jargon.
new_text = """
First-timer looking to get out of here.

Hi, I'm writing about my interest in travelling to the outer limits!

What kind of craft can I buy? What is easiest to access from this 3rd rock?

Let me know how to do that please.
"""

result = make_prediction(new_text)

for idx, category in enumerate(df_test["Class Name"].cat.categories):
    print(f"{category}: {result[idx] * 100:0.2f}%")

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step
sci.crypt: 0.00%
sci.electronics: 0.11%
sci.med: 0.02%
sci.space: 99.87%
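
If only the single most likely category is needed, the probability vector can be reduced to one label. The sketch below uses the rounded probabilities printed above as plain numbers. `top_prediction` is a hypothetical helper, not part of the notebook.

```python
def top_prediction(categories, probabilities):
    """Return the (category, probability) pair with the highest probability."""
    if len(categories) != len(probabilities):
        raise ValueError("categories and probabilities must have the same length")
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return categories[best], probabilities[best]


# Rounded values copied from the printed output above.
categories = ["sci.crypt", "sci.electronics", "sci.med", "sci.space"]
probabilities = [0.0000, 0.0011, 0.0002, 0.9987]

label, confidence = top_prediction(categories, probabilities)
print(label, confidence)  # sci.space 0.9987
```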
