# Lesson 2: Embeddings

Note: The numeric values of embeddings you see in your notebook may vary slightly from those filmed.

### Setup
Load needed API keys and relevant Python libaries.

In [1]:
%%capture
!pip install cohere umap-learn altair datasets
!pip install python-dotenv

In [2]:
import os
import getpass
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) # read local .env file

In [3]:
def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")

In [4]:
_set_env("COHERE_API_KEY")

COHERE_API_KEY: ··········


In [5]:
import cohere
co = cohere.Client(os.environ['COHERE_API_KEY'])

In [7]:
import pandas as pd

## Word Embeddings

Consider a very small dataset of three words.

In [8]:
three_words = pd.DataFrame({'text':
  [
      'joy',
      'happiness',
      'potato'
  ]})

three_words

Unnamed: 0,text
0,joy
1,happiness
2,potato


Let's create the embeddings for the three words:
You may see an 'unknown field' warning which can be ignored.

In [9]:
three_words_emb = co.embed(texts=list(three_words['text']),
                           model='embed-english-v2.0').embeddings

In [10]:
word_1 = three_words_emb[0]
word_2 = three_words_emb[1]
word_3 = three_words_emb[2]

In [11]:
word_1[:10]

[2.3203125,
 -0.18334961,
 -0.578125,
 -0.7314453,
 -2.2050781,
 -2.59375,
 0.35205078,
 -1.6220703,
 0.27954102,
 0.3083496]

## Sentence Embeddings

Consider a very small dataset of three sentences.

In [12]:
sentences = pd.DataFrame({'text':
  [
   'Where is the world cup?',
   'The world cup is in Qatar',
   'What color is the sky?',
   'The sky is blue',
   'Where does the bear live?',
   'The bear lives in the the woods',
   'What is an apple?',
   'An apple is a fruit',
  ]})

sentences

Unnamed: 0,text
0,Where is the world cup?
1,The world cup is in Qatar
2,What color is the sky?
3,The sky is blue
4,Where does the bear live?
5,The bear lives in the the woods
6,What is an apple?
7,An apple is a fruit


Let's create the embeddings for the three sentences:

In [13]:
emb = co.embed(texts=list(sentences['text']),
               model='embed-english-v2.0').embeddings

# Explore the 10 first entries of the embeddings of the 3 sentences:
for e in emb:
    print(e[:3])

[0.27319336, -0.37768555, -1.0273438]
[0.49804688, 1.2236328, 0.4074707]
[-0.23571777, -0.9375, 0.9614258]
[0.08300781, -0.32080078, 0.9272461]
[0.49780273, -0.35058594, -1.6171875]
[1.2294922, -1.3779297, -1.8378906]
[0.15686035, -0.92041016, 1.5996094]
[1.0761719, -0.7211914, 0.9296875]


In [14]:
len(emb[0])

4096

In [None]:
#import umap
#import altair as alt

The next code cell is for hiding some warnings that appear when importing the `umap_plot` library.

In [15]:
# hide the warnings that would appear when importing the UMAP library
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings
warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)

In [26]:
%%writefile utils.py

import umap
import altair as alt

from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings

warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)


def umap_plot(text, emb):

    cols = list(text.columns)
    # UMAP reduces the dimensions from 1024 to 2 dimensions that we can plot
    reducer = umap.UMAP(n_neighbors=2)
    umap_embeds = reducer.fit_transform(emb)
    # Prepare the data to plot and interactive visualization
    # using Altair
    #df_explore = pd.DataFrame(data={'text': qa['text']})
    #print(df_explore)

    #df_explore = pd.DataFrame(data={'text': qa_df[0]})
    df_explore = text.copy()
    df_explore['x'] = umap_embeds[:,0]
    df_explore['y'] = umap_embeds[:,1]

    # Plot
    chart = alt.Chart(df_explore).mark_circle(size=60).encode(
        x=#'x',
        alt.X('x',
            scale=alt.Scale(zero=False)
        ),
        y=
        alt.Y('y',
            scale=alt.Scale(zero=False)
        ),
        tooltip=cols
        #tooltip=['text']
    ).properties(
        width=700,
        height=400
    )
    return chart

def umap_plot_big(text, emb):

    cols = list(text.columns)
    # UMAP reduces the dimensions from 1024 to 2 dimensions that we can plot
    reducer = umap.UMAP(n_neighbors=100)
    umap_embeds = reducer.fit_transform(emb)
    # Prepare the data to plot and interactive visualization
    # using Altair
    #df_explore = pd.DataFrame(data={'text': qa['text']})
    #print(df_explore)

    #df_explore = pd.DataFrame(data={'text': qa_df[0]})
    df_explore = text.copy()
    df_explore['x'] = umap_embeds[:,0]
    df_explore['y'] = umap_embeds[:,1]

    # Plot
    chart = alt.Chart(df_explore).mark_circle(size=60).encode(
        x=#'x',
        alt.X('x',
            scale=alt.Scale(zero=False)
        ),
        y=
        alt.Y('y',
            scale=alt.Scale(zero=False)
        ),
        tooltip=cols
        #tooltip=['text']
    ).properties(
        width=700,
        height=400
    )
    return chart

def umap_plot_old(sentences, emb):
    # UMAP reduces the dimensions from 1024 to 2 dimensions that we can plot
    reducer = umap.UMAP(n_neighbors=2)
    umap_embeds = reducer.fit_transform(emb)
    # Prepare the data to plot and interactive visualization
    # using Altair
    #df_explore = pd.DataFrame(data={'text': qa['text']})
    #print(df_explore)

    #df_explore = pd.DataFrame(data={'text': qa_df[0]})
    df_explore = sentences
    df_explore['x'] = umap_embeds[:,0]
    df_explore['y'] = umap_embeds[:,1]

    # Plot
    chart = alt.Chart(df_explore).mark_circle(size=60).encode(
        x=#'x',
        alt.X('x',
            scale=alt.Scale(zero=False)
        ),
        y=
        alt.Y('y',
            scale=alt.Scale(zero=False)
        ),
        tooltip=['text']
    ).properties(
        width=700,
        height=400
    )
    return chart

Overwriting utils.py


In [27]:
from utils import umap_plot

In [28]:
chart = umap_plot(sentences, emb)

  warn(


In [29]:
chart.interactive()

## Articles Embeddings

In [66]:
import pandas as pd
wiki_articles = pd.read_pickle('wikipedia.pkl')
wiki_articles

Unnamed: 0,id,title,text,emb
0,378179,Inca Civil War,"The Inca Civil War, the Inca Dynastic War, or ...","[-1.4873047, -2.1738281, 0.44213867, 1.0185547..."
1,47341,Harm,Harm is physical or psychological or emotional...,"[1.1445312, -0.24572754, 0.06451416, 1.3339844..."
2,115720,Colleges and universities in Puerto Rico,"thumb|Universidad de Puerto Rico, Recinto de R...","[-0.081848145, -1.2001953, -1.7929688, 0.32373..."
3,519703,Mane,"Mane may refer to: *Mane (horse), the line of ...","[-1.234375, 0.67871094, 1.6796875, -0.6513672,..."
4,729568,"Bayfield County, Wisconsin",Bayfield County is a county in the U.S. state ...,"[-1.4804688, -0.045684814, -0.71435547, 1.6015..."
...,...,...,...,...
94,340725,Stella McCartney,"Stella Nina McCartney, (born 13 September 1971...","[0.2956543, 1.0585938, -1.0400391, -0.4152832,..."
95,375561,Höytiäinen,Höytiäinen is a lake in Finland. It covers an ...,"[1.3691406, -0.13745117, -3.8886719, 1.5498047..."
96,74713,Hydrogen atom,A hydrogen atom is an atom of the chemical ele...,"[-0.7314453, 0.46777344, 0.07824707, -1.373046..."
97,413284,Bucheon,Bucheon is a Korean city in Gyeonggi Province ...,"[-1.6640625, 0.7392578, 0.5830078, -1.0380859,..."


In [55]:
import numpy as np
from utils import umap_plot_big

In [67]:
articles = wiki_articles[['title', 'text']]
embeds = np.array([d for d in wiki_articles['emb']])

chart = umap_plot_big(articles, embeds)
chart.interactive()



In [57]:
len(embeds)

15

In [34]:
%%capture
!pip install html2text
!pip install wikitextparser

In [35]:
from threading import Thread
import json
import re
from html2text import html2text as htt
import wikitextparser as wtp


def dewiki(text):
    text = wtp.parse(text).plain_text()  # wiki to plaintext
    text = htt(text)  # remove any HTML
    text = text.replace('\\n',' ')  # replace newlines
    text = re.sub('\s+', ' ', text)  # replace excess whitespace
    return text


def analyze_chunk(text):
    try:
        if '<redirect title="' in text:  # this is not the main article
            return None
        if '(disambiguation)' in text:  # this is not an article
            return None
        else:
            title = text.split('<title>')[1].split('</title>')[0]
            title = htt(title)
            if ':' in title:  # most articles with : in them are not articles we care about
                return None
        serial = text.split('<id>')[1].split('</id>')[0]
        content = text.split('</text')[0].split('<text')[1].split('>', maxsplit=1)[1]
        content = dewiki(content)
        return {'title': title.strip(), 'text': content.strip(), 'id': serial.strip()}
    except Exception as oops:
        print(oops)
        return None


def save_article(article, savedir):
    doc = analyze_chunk(article)
    if doc:
        print('SAVING:', doc['title'])
        filename = doc['id'] + '.json'
        with open(savedir + filename, 'w', encoding='utf-8') as outfile:
            json.dump(doc, outfile, sort_keys=True, indent=1, ensure_ascii=False)


def process_file_text(filename, savedir):
    article = ''
    with open(filename, 'r', encoding='utf-8') as infile:
        for line in infile:
            if '<page>' in line:
                article = ''
            elif '</page>' in line:  # end of article
                Thread(target=save_article, args=(article, savedir)).start()
            else:
                article += line

In [36]:
wiki_xml_file = '/content/simplewiki-20240920-pages-articles-multistream.xml'  # update this
json_save_dir = './wiki_plaintext/'

In [37]:
process_file_text(wiki_xml_file, json_save_dir)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
SAVING: SAVING: Richard Osman
SAVING: Angelica Ortiz
Paola Medina
SAVING: The Sure Thing
SAVING: Adriana Vinas JoySAVING: Stratopause

SAVING: Ana Sofía Jusino
SAVING: Jonathan Dowling
SAVING: Samuel J. Randall
SAVING: Alba Hernández
SAVING: Rosita Fornés
SAVING: Armando Silvestre
SAVING: María Victoria
SAVING: Gloria Mange
SAVING: Ciudad Obregón
SAVING: Rosa de Castilla
SAVING: Neira Ortiz Ruiz
SAVING: Ed Ames
SAVING: Warwick, New York
SAVING: Florida, Orange County, New York
SAVING: Middletown, Orange County, New York
SAVING: Port Jervis, New York
SAVING: Matamoras, Pennsylvania
SAVING: Montague Township, New Jersey
SAVING: Greenville, Orange County, New York
SAVING: Deerpark, New York
SAVING: Middletown, Delaware County, New York
SAVING: Middletown, New York
SAVING: Florida, Montgomery County, New York
SAVING: Cohoes, New York
SAVING: Mohawk River
SAVING: Mohawk people
SAVING: Navarre, Florida
SAVING: New York, Florida

In [46]:
import pandas as pd
import cohere

def add_embeddings_column(
    df: pd.DataFrame,
    cohere_client: cohere.Client,
    text_col: str = "text",
    emb_col: str = "emb",
    model_name: str = "embed-english-v2.0"
) -> pd.DataFrame:
    """
    Add an embeddings column to a pandas DataFrame using Cohere's embedding API.

    Parameters
    ----------
    df : pd.DataFrame
        The DataFrame to which an embeddings column will be added.
    cohere_client : cohere.Client
        An authenticated Cohere client instance.
    text_col : str, optional
        The column name in df that contains the text to embed. Default is "text".
    emb_col : str, optional
        The name of the new column to store embeddings. Default is "emb".
    model_name : str, optional
        The Cohere embedding model to use. Default is "embed-english-v2.0".

    Returns
    -------
    pd.DataFrame
        The same DataFrame with an additional column for embeddings.
    """
    # Extract the text data as a list
    texts = df[text_col].tolist()

    # Call Cohere to get embeddings
    response = cohere_client.embed(
        texts=texts,
        model=model_name
    )

    # response.embeddings is a list of embedding vectors (lists of floats)
    df[emb_col] = response.embeddings

    return df


def get_embedding(text, cohere_client, model_name="embed-english-v2.0"):
    """
    For a single string `text`, call Cohere to get the embedding.
    Returns a single embedding vector (list of floats).
    """
    response = cohere_client.embed(texts=[text], model=model_name)
    return response.embeddings[0]

In [61]:
import os
import json
import pandas as pd

def json_dir_to_pickle(json_dir: str, output_pickle_file: str) -> pd.DataFrame:
    """
    Traverse a directory of JSON files and create a single pandas DataFrame,
    then save the DataFrame to a pickle file.

    Each JSON file is expected to have the following keys:
        - "id"
        - "title"
        - "text"

    Parameters
    ----------
    json_dir : str
        Path to the directory containing the JSON files.
    output_pickle_file : str
        The filename (or path) for the resulting pickle file.

    Returns
    -------
    pd.DataFrame
        A pandas DataFrame containing concatenated data from all JSON files.

    Example
    -------
    >>> df = json_dir_to_pickle("path_to_json_dir", "output.pkl")
    >>> df.head()
    """
    data_rows = []

    # Traverse the specified directory
    for file_name in os.listdir(json_dir):
        # Only process JSON files
        if file_name.endswith(".json"):
            file_path = os.path.join(json_dir, file_name)

            # Open and load the JSON
            with open(file_path, "r", encoding="utf-8") as f:
                record = json.load(f)

                # Append the relevant fields to data_rows
                data_rows.append({
                    "id": record.get("id"),
                    "title": record.get("title"),
                    "text": record.get("text")
                })

    # Convert the collected records into a DataFrame
    df = pd.DataFrame(data_rows)
    # df['emb'] = df['text'].apply(lambda x: get_embedding(x, co))

    # Save the DataFrame to a pickle file
    df.to_pickle(output_pickle_file)

    return df


In [62]:
# Specify the directory containing JSON files
json_directory_path = "/content/wiki_plaintext"

# Specify the output pickle file name or path
pickle_output = "wikipedia.pkl"

# Call the function
df_result = json_dir_to_pickle(json_directory_path, pickle_output)

In [63]:
df_result = df_result.iloc[:99]
df_result['emb'] = df_result['text'].apply(lambda x: get_embedding(x, co))
df_result

Unnamed: 0,id,title,text,emb
0,378179,Inca Civil War,"The Inca Civil War, the Inca Dynastic War, or ...","[-1.4873047, -2.1738281, 0.44213867, 1.0185547..."
1,47341,Harm,Harm is physical or psychological or emotional...,"[1.1445312, -0.24572754, 0.06451416, 1.3339844..."
2,115720,Colleges and universities in Puerto Rico,"thumb|Universidad de Puerto Rico, Recinto de R...","[-0.081848145, -1.2001953, -1.7929688, 0.32373..."
3,519703,Mane,"Mane may refer to: *Mane (horse), the line of ...","[-1.234375, 0.67871094, 1.6796875, -0.6513672,..."
4,729568,"Bayfield County, Wisconsin",Bayfield County is a county in the U.S. state ...,"[-1.4804688, -0.045684814, -0.71435547, 1.6015..."
...,...,...,...,...
94,340725,Stella McCartney,"Stella Nina McCartney, (born 13 September 1971...","[0.2956543, 1.0585938, -1.0400391, -0.4152832,..."
95,375561,Höytiäinen,Höytiäinen is a lake in Finland. It covers an ...,"[1.3691406, -0.13745117, -3.8886719, 1.5498047..."
96,74713,Hydrogen atom,A hydrogen atom is an atom of the chemical ele...,"[-0.7314453, 0.46777344, 0.07824707, -1.373046..."
97,413284,Bucheon,Bucheon is a Korean city in Gyeonggi Province ...,"[-1.6640625, 0.7392578, 0.5830078, -1.0380859,..."


In [64]:
df_result.shape

(99, 4)

In [65]:
df_result.to_pickle(pickle_output)