# Classifying embeddings with Keras and the Gemini API

- Use the embeddings produced by the Gemini API to train a model that classifies newsgroup posts into their categories (the newsgroup itself) based on the post contents.

- This technique uses the Gemini API's embeddings as input, which avoids the need to train on text input directly. As a result, it performs quite well with relatively few examples compared to training a text model from scratch.

In [1]:
import google.generativeai as genai
from IPython.display import Markdown
import os

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])


## Dataset

The 20 Newsgroups Text Dataset contains 18,000 newsgroup posts on 20 topics, divided into training and test sets. The split between the training and test sets is based on messages posted before and after a specific date. For this tutorial, you will use sampled subsets of the training and test sets and perform some processing using pandas.

In [5]:
from sklearn.datasets import fetch_20newsgroups

newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test")

# View list of class names for dataset
newsgroups_train.target_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

Here is an example of what a record from the training set looks like.

In [6]:
print(newsgroups_train.data[0])

From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.umd.edu
Organization: University of Maryland, College Park
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----







To remove sensitive information such as names and email addresses, keep only the subject and body of each message, and strip any email addresses that remain in the body. This optional step turns the input into generic text rather than email posts, so the same approach carries over to other contexts.

In [8]:
import email
import re

import pandas as pd

def preprocess_newsgroup_row(data):
    # Extract only the subject and body
    msg = email.message_from_string(data)
    text = f"{msg['Subject']}\n\n{msg.get_payload()}"
    # Strip any remaining email addresses
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    # Truncate each entry to 5,000 characters
    text = text[:5000]

    return text

def preprocess_newsgroup_data(newsgroup_dataset):
    # Put data points into dataframe
    df = pd.DataFrame(
        {"Text": newsgroup_dataset.data, "Label": newsgroup_dataset.target}
    )
    # Clean up the text
    df["Text"] = df["Text"].apply(preprocess_newsgroup_row)
    # Match label to target name index
    df["Class Name"] = df["Label"].map(lambda l: newsgroup_dataset.target_names[l])

    return df
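
To check what the preprocessing keeps and drops, the sketch below reruns the same steps on a made-up post. The sample text is hypothetical, not a record from the dataset, and the function is repeated so the block runs on its own.

```python
import email
import re


def preprocess_newsgroup_row(data: str) -> str:
    # Same steps as the cell above: keep only the Subject header and the body.
    msg = email.message_from_string(data)
    text = f"{msg['Subject']}\n\n{msg.get_payload()}"
    # Strip any email addresses that remain in the body.
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    # Truncate to 5,000 characters.
    return text[:5000]


# Hypothetical post in the same header/body format as the dataset records.
sample_post = (
    "From: jdoe@example.edu (Jane Doe)\n"
    "Subject: Orbit question\n"
    "Organization: Example University\n"
    "Lines: 2\n"
    "\n"
    "Ask jsmith@example.com about the launch window.\n"
)

cleaned = preprocess_newsgroup_row(sample_post)
print(repr(cleaned))
```

The `From`, `Organization`, and `Lines` headers disappear because only `Subject` is kept, and the address inside the body is removed by the regular expression.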


In [9]:
# Apply preprocessing function to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)

df_train.head()

| | Text | Label | Class Name |
|---|---|---|---|
| 0 | WHAT car is this!?\n\n I was wondering if anyo... | 7 | rec.autos |
| 1 | SI Clock Poll - Final Call\n\nA fair number of... | 4 | comp.sys.mac.hardware |
| 2 | PB questions...\n\nwell folks, my mac plus fin... | 4 | comp.sys.mac.hardware |
| 3 | Re: Weitek P9000 ?\n\nRobert J.C. Kyanko () wr... | 1 | comp.graphics |
| 4 | Re: Shuttle Launch Question\n\nFrom article <>... | 14 | sci.space |


- Next, sample the data by taking 100 data points per category from the training dataset (and 25 per category from the test dataset), then drop the categories you don't need.
- Keep only the science categories for comparison.

In [10]:
def sample_data(df, num_samples, classes_to_keep):
    # Sample rows, selecting num_samples of each Label.
    df = (
        df.groupby("Label")[df.columns]
        .apply(lambda x: x.sample(num_samples))
        .reset_index(drop=True)
    )

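    # Keep only the classes whose name matches classes_to_keep.
    # .copy() avoids pandas' SettingWithCopyWarning when adding columns below.
    df = df[df["Class Name"].str.contains(classes_to_keep)].copy()

    # We have fewer categories now, so re-calibrate the label encoding.
    df["Class Name"] = df["Class Name"].astype("category")
    df["Encoded Label"] = df["Class Name"].cat.codes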

    return df
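
The sampling and re-encoding steps can be checked on a tiny DataFrame. This sketch swaps in pandas' `GroupBy.sample` for the `groupby(...).apply(...)` pattern above; the `random_state` parameter and the toy data are hypothetical additions for illustration.

```python
import pandas as pd


def sample_data(df, num_samples, classes_to_keep, random_state=None):
    # Draw num_samples rows from every Label group (stratified sampling).
    df = df.groupby("Label").sample(n=num_samples, random_state=random_state)
    df = df.reset_index(drop=True)
    # Keep only classes whose name matches the pattern; .copy() avoids
    # pandas' SettingWithCopyWarning when new columns are added below.
    df = df[df["Class Name"].str.contains(classes_to_keep)].copy()
    # Re-encode the surviving classes as consecutive integers 0..k-1.
    df["Class Name"] = df["Class Name"].astype("category")
    df["Encoded Label"] = df["Class Name"].cat.codes
    return df


# Hypothetical toy data: 5 posts for each of 4 classes (Labels 0..3).
names = ["rec.autos", "sci.crypt", "sci.med", "sci.space"]
toy = pd.DataFrame(
    {
        "Text": [f"post {i}" for i in range(20)],
        "Label": [i // 5 for i in range(20)],
        "Class Name": [names[i // 5] for i in range(20)],
    }
)

sampled = sample_data(toy, num_samples=2, classes_to_keep="sci", random_state=0)
print(sampled[["Label", "Class Name", "Encoded Label"]])
```

After dropping `rec.autos`, the original Labels 1, 2, 3 become Encoded Labels 0, 1, 2, which is the contiguous range a classifier's output layer expects. Passing `random_state` also makes the draw reproducible, which the cell above does not do.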

In [11]:
TRAIN_NUM_SAMPLES = 100
TEST_NUM_SAMPLES = 25
CLASSES_TO_KEEP = "sci"  # Class name should contain 'sci' to keep science categories

df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)

In [12]:
df_train.value_counts("Class Name")

Class Name
sci.crypt          100
sci.electronics    100
sci.med            100
sci.space          100
Name: count, dtype: int64

In [13]:
df_test.value_counts("Class Name")

Class Name
sci.crypt          25
sci.electronics    25
sci.med            25
sci.space          25
Name: count, dtype: int64

## Create the embeddings

- Generate embeddings for each piece of text using the Gemini API embeddings endpoint.

<div class="alert alert-block alert-info">
<b>NOTE:</b> Embeddings are computed one at a time, so large sample sizes can take a long time!
</div>

**Task types**

The `text-embedding-004` model supports a task type parameter that generates embeddings tailored for the specific task.

| Task Type | Description |
|---|---|
| RETRIEVAL_QUERY | Specifies the given text is a query in a search/retrieval setting. |
| RETRIEVAL_DOCUMENT | Specifies the given text is a document in a search/retrieval setting. |
| SEMANTIC_SIMILARITY | Specifies the given text will be used for Semantic Textual Similarity (STS). |
| CLASSIFICATION | Specifies that the embeddings will be used for classification. |
| CLUSTERING | Specifies that the embeddings will be used for clustering. |
| FACT_VERIFICATION | Specifies that the given text will be used for fact verification. |

For this example, you will be performing classification, so use the `CLASSIFICATION` task type.

In [16]:
from tqdm.auto import tqdm

tqdm.pandas()

from google.api_core import retry

@retry.Retry(timeout=300.0)
def embed_fn(text: str) -> list[float]:
    # You will be performing classification, so set task_type accordingly.
    response = genai.embed_content(
        model="models/text-embedding-004", content=text, task_type="classification"
    )

    return response["embedding"]


def create_embeddings(df):
    df["Embeddings"] = df["Text"].progress_apply(embed_fn)
    return df

This version sends one request per text. For larger datasets, consider implementing batch or parallel/asynchronous embedding generation instead.
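
One way to batch and parallelize embedding calls is sketched below. The helper names `chunked` and `embed_in_batches`, the batch size of 100, and the worker count are hypothetical choices, not part of the notebook. The batch function is passed in so the sketch runs without an API key. With the Gemini SDK it might wrap `genai.embed_content(model="models/text-embedding-004", content=list_of_texts, task_type="classification")["embedding"]`, assuming your installed `google-generativeai` version accepts a list for `content` (check its documentation). Keep `max_workers` low to stay within rate limits.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Sequence

Embedding = list[float]


def chunked(items: Sequence[str], size: int) -> list[Sequence[str]]:
    # Split items into consecutive chunks of at most `size` elements.
    return [items[i : i + size] for i in range(0, len(items), size)]


def embed_in_batches(
    texts: Sequence[str],
    embed_batch: Callable[[Sequence[str]], list[Embedding]],
    batch_size: int = 100,  # assumed per-request limit
    max_workers: int = 4,
) -> list[Embedding]:
    # Send each chunk to embed_batch, running several chunks concurrently.
    batches = chunked(texts, batch_size)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order, so results line up with texts.
        results = list(pool.map(embed_batch, batches))
    # Flatten the per-batch lists back into one embedding per text.
    return [emb for batch in results for emb in batch]


def fake_embed_batch(batch: Sequence[str]) -> list[Embedding]:
    # Stand-in for a real API call: a 2-d "embedding" per text.
    return [[float(len(t)), 1.0] for t in batch]


texts = [f"text {i}" for i in range(250)]
embeddings = embed_in_batches(texts, fake_embed_batch, batch_size=100)
print(len(embeddings))
```

Because `pool.map` returns results in submission order, the flattened list can be assigned straight back to a DataFrame column, for example `df["Embeddings"] = embed_in_batches(df["Text"].tolist(), ...)`.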

In [17]:
df_train = create_embeddings(df_train)
df_test = create_embeddings(df_test)


In [18]:
df_train.head()

| | Text | Label | Class Name | Encoded Label | Embeddings |
|---|---|---|---|---|---|
| 1100 | Re: Illegal Wiretaps (was\n\nIn article <>, di... | 11 | sci.crypt | 0 | [-0.0010601141, 0.022317966, -0.03246477, 0.02... |
| 1101 | Re: Fighting the Clipper Initiative\n\n (Phili... | 11 | sci.crypt | 0 | [0.005631752, 0.019081194, -0.041051492, 0.012... |
| 1102 | Re: Once tapped, your code is no good any more... | 11 | sci.crypt | 0 | [-0.025410943, 0.013192673, -0.053671326, -0.0... |
| 1103 | More Clipper Stuff\n\nAs of yet, there has bee... | 11 | sci.crypt | 0 | [0.0063751615, 0.03159215, -0.059843536, 0.022... |
| 1104 | Re: Once tapped, your code is no good any more... | 11 | sci.crypt | 0 | [-0.0067106336, 0.0061708996, -0.029784752, 0.... |
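
The `Embeddings` column holds one Python list per row, but a Keras model expects a numeric array. A minimal sketch of the conversion follows; the helper name `to_model_inputs` and the 3-dimensional toy embeddings are hypothetical (real `text-embedding-004` vectors should be 768-dimensional, so check `x.shape` on your data).

```python
import numpy as np
import pandas as pd


def to_model_inputs(df: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
    # Stack the per-row embedding lists into a 2-D float32 matrix
    # of shape (num_rows, embedding_dim).
    x = np.array(df["Embeddings"].tolist(), dtype=np.float32)
    # Integer class ids 0..k-1, as produced by sample_data.
    y = df["Encoded Label"].to_numpy()
    return x, y


# Hypothetical toy rows with 3-dimensional embeddings.
toy = pd.DataFrame(
    {
        "Embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        "Encoded Label": [0, 1],
    }
)

x, y = to_model_inputs(toy)
print(x.shape, y)  # (2, 3) [0 1]
```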


## Build a simple classification model

- Define a simple model that accepts the raw embedding data as input, has one hidden layer, and an output layer specifying the class probabilities.
- The prediction will correspond to the probability of a piece of text being a particular class of news.

When you run the model, Keras takes care of details like shuffling the data points, calculating metrics, and other ML boilerplate.
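
To make the architecture concrete, here is a NumPy sketch of the forward pass such a model computes: embedding in, one hidden layer, then softmax class probabilities. This is not the Keras code the notebook builds. The ReLU hidden activation, the layer sizes, and the random (untrained) weights are assumptions for illustration.

```python
import numpy as np


def relu(z: np.ndarray) -> np.ndarray:
    return np.maximum(z, 0.0)


def softmax(z: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def forward(x, w1, b1, w2, b2):
    # Hidden layer: a dense layer with ReLU activation.
    h = relu(x @ w1 + b1)
    # Output layer: one score per class, turned into probabilities.
    return softmax(h @ w2 + b2)


rng = np.random.default_rng(0)
input_size, hidden_size, num_classes = 768, 32, 4  # hypothetical sizes
x = rng.normal(size=(5, input_size))  # stand-in for 5 embeddings
w1 = rng.normal(scale=0.05, size=(input_size, hidden_size))
b1 = np.zeros(hidden_size)
w2 = rng.normal(scale=0.05, size=(hidden_size, num_classes))
b2 = np.zeros(num_classes)

probs = forward(x, w1, b1, w2, b2)
print(probs.shape)  # (5, 4)
print(probs.argmax(axis=1))  # predicted class id per row
```

Each row of `probs` sums to 1, and `argmax` picks the most likely class. In Keras, a comparable stack would be `keras.Sequential([keras.layers.Input(shape=(input_size,)), keras.layers.Dense(hidden_size, activation="relu"), keras.layers.Dense(num_classes, activation="softmax")])`, with the weights learned through `compile` and `fit` rather than drawn at random.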
