<a href="https://colab.research.google.com/github/RomanEngeler1805/cohere-hackathon-Sep22/blob/main/Cohere_Embed_Analyse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cohere Analyse
This notebook was developed during the *Cohere AI Hackathon #2* to make the Cohere embed endpoint more accessible. <br><br>

It runs and displays all outputs inline. The goal is quick, low-latency experimentation.
<br><br>

It consists of four parts:
<ul>
<li>Data upload</li>
<li>Exploratory data analysis (EDA)</li>
<li>Cluster analysis</li>
<li>Semantic search</li>
</ul>

<br>

TODO: search for the TODO keywords and insert a valid Cohere API key.

<br>
Make sure to run all cells sequentially without skipping any section.<br>
A CPU is sufficient, as all the heavy lifting is done on Cohere's side.


In [7]:
# cohere api key 
api_key = None # TODO add valid api key
model_size = 'medium' # start with 'small' or 'medium' for speed.
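
Instead of pasting the key into the notebook, it can be read from the environment. The following is a minimal sketch; the helper `load_api_key` and the variable name `COHERE_API_KEY` are assumptions made for illustration and are not part of the original notebook.

```python
import os


def load_api_key(env_var: str = "COHERE_API_KEY") -> str:
    """Return the API key stored in an environment variable.

    Both this helper and the default variable name are hypothetical.
    """
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"Set the {env_var} environment variable first.")
    return key


# Usage in the notebook (commented out so the sketch runs without a key):
# api_key = load_api_key()
```

This keeps the key out of the saved notebook, which matters when the file is shared or pushed to a public repository.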

## Installs & Imports

In [8]:
!pip install cohere umap-learn altair annoy datasets bertopic transformers streamlit pyngrok -q

In [9]:
import cohere
import numpy as np
import pandas as pd
from datasets import load_dataset
import umap
import altair as alt
from annoy import AnnoyIndex
import warnings
from sklearn.cluster import KMeans
from bertopic._ctfidf import ClassTFIDF
from sklearn.feature_extraction.text import CountVectorizer
from typing import Tuple

warnings.filterwarnings('ignore')
pd.set_option('display.max_colwidth', None)

## Data Preparation

### Helper Functions

In [10]:
# Load & Prepare dataset
def get_dataset(df: pd.DataFrame, text: str, title: str, max_length: int=100) -> pd.DataFrame:
  '''
  inputs:
  - df: dataframe of data
  - text: name of text column
  - title: name of title column
  - max_length: parameter to limit length for the sake of speed

  outputs:
  - df: dataframe with 'text' and 'title' column
  '''
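  # rename on a copy so the caller's dataframe is left untouched
  df = df.rename(columns={text: 'text', title: 'title'})
  # keep only the first max_length rows and reset the index to 0..n-1,
  # so that positional ids (e.g. in the search index) match the dataframe index
  df = df[['title', 'text']].head(max_length).reset_index(drop=True)
  return df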

In [11]:
# for the amazon dataset
def filter_dataset(row):
  return row['language'] == 'en'

### Data Upload

In [27]:
# get the amazon dataset
a = load_dataset('amazon_reviews_multi')
a = a.filter(filter_dataset)
a.set_format('pandas')
df = a['train'][:]

df = df.sample(frac=1)
df.to_csv('amazon_reviews')

# format the dataset
df = get_dataset(df, text='review_body', title='review_title')






## Exploratory Data Analysis

In [13]:
# inspect the dataset
print('Data columns:')
all_columns = df.columns.to_list()
print(all_columns)
print('\n')

print('Number of lines')
print(f'The number of lines is {df.shape[0]}')
print('\n')

print('Missing values')
print(df.isnull().sum())

Data columns:
['title', 'text']


Number of lines
The number of lines is 100


Missing values
title    0
text     0
dtype: int64


In [14]:
# inspect the dataset
df.head(5)

,title,text
0,I'll spend twice the amount of time boxing up the whole useless thing and send it back with a 1-star review ...,"Arrived broken. Manufacturer defect. Two of the legs of the base were not completely formed, so there was no way to insert the casters. I unpackaged the entire chair and hardware before noticing this. So, I'll spend twice the amount of time boxing up the whole useless thing and send it back with a 1-star review of part of a chair I never got to sit in. I will go so far as to include a picture of what their injection molding and quality assurance process missed though. I will be hesitant to buy again. It makes me wonder if there aren't missing structures and supports that don't impede the assembly process."
1,Not use able,the cabinet dot were all detached from backing... got me
2,The product is junk.,I received my first order of this product and it was broke so I ordered it again. The second one was broke in more places than the first. I can't blame the shipping process as it's shrink wrapped and boxed.
3,Fucking waste of money,"This product is a piece of shit. Do not buy. Doesn't work, and then I try to call for customer support, it won't take my number. Fucking rip off!"
4,bubble,went through 3 in one day doesn't fit correct and couldn't get bubbles out (better without)


## Clustering

In [15]:
# cohere api
co = cohere.Client(api_key)
title = 'K-Means clustering in 2D umap visualisation'

### Helper Functions

In [16]:
def get_embeddings(df: pd.DataFrame) -> Tuple[list, list]:
  '''
  input:
  - df: dataframe with 'text' column

  output:
  - embeds: cohere embedding
  - umap_embeds: umap embeddings -> dimensionality reduction technique
  '''
  embeds = co.embed(texts=list(df['text']),
                  model=model_size,
                  truncate='LEFT').embeddings
  reducer = umap.UMAP(n_neighbors=100) 
  umap_embeds = reducer.fit_transform(embeds)
  return (embeds, umap_embeds)
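
`get_embeddings` sends all texts in a single request, which is fine for the 100 rows used here. For larger datasets the texts may need to be split into batches. The sketch below assumes the endpoint caps the number of texts per request; the default of 96 is an assumption to be checked against the current Cohere documentation, and `embed_in_batches` and `fake_embed` are hypothetical names used only for illustration.

```python
from typing import Callable, List, Sequence


def embed_in_batches(texts: Sequence[str],
                     embed_fn: Callable[[List[str]], List[List[float]]],
                     batch_size: int = 96) -> List[List[float]]:
    """Call embed_fn on consecutive slices of texts and concatenate the results."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    embeddings: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        embeddings.extend(embed_fn(batch))
    return embeddings


# Stand-in for the real API call so the sketch runs offline:
# each "embedding" is just the length of the text.
def fake_embed(batch: List[str]) -> List[List[float]]:
    return [[float(len(text))] for text in batch]


vectors = embed_in_batches(["a", "bb", "ccc"], fake_embed, batch_size=2)
```

In the notebook, `embed_fn` could be `lambda batch: co.embed(texts=batch, model=model_size, truncate='LEFT').embeddings`.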

In [17]:
def get_keywords(df: pd.DataFrame, n_clusters: int=8) -> pd.DataFrame:
  '''
  inputs:
  - df: dataframe with columns ('text', 'title', 'embeds', 'x', 'y')
  - n_clusters: number of clusters in k-means

  outputs:
  - df: dataframe with columns ('text', 'title', 'embeds', 'x', 'y', 'cluster', 'keywords')
  '''

  # k-means clustering
  kmeans_model = KMeans(n_clusters=n_clusters, random_state=0)
  classes = kmeans_model.fit_predict(list(df['embeds'].values))

  # get keywords from each cluster
  # - group documents by cluster assignment
  # - get class-based tf-idf (c-TF-IDF) scores for the words in each cluster
  documents =  df['title']
  documents = pd.DataFrame({"Document": documents,
                            "ID": range(len(documents)),
                            "Topic": None})
  documents['Topic'] = classes
  documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
  count_vectorizer = CountVectorizer(stop_words="english").fit(documents_per_topic.Document)
  count = count_vectorizer.transform(documents_per_topic.Document)
  # get_feature_names() was removed in scikit-learn 1.2; use the *_out variant
  words = count_vectorizer.get_feature_names_out()
  ctfidf = ClassTFIDF().fit_transform(count).toarray()
  words_per_class = {label: [words[index] for index in ctfidf[label].argsort()[-10:]] for label in documents_per_topic.Topic}

  # add cluster assignment and keywords per cluster to dataframe
  df['cluster'] = classes
  df['keywords'] = df['cluster'].map(lambda topic_num: ", ".join(np.array(words_per_class[topic_num])[:]))

  return df
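
To make the keyword step more concrete, here is a simplified, pure-Python sketch of class-based TF-IDF. It is not BERTopic's implementation: it assumes the weighting `tf * log(1 + A / f)`, where `tf` is a word's relative frequency within a cluster, `A` is the average number of words per cluster and `f` is the word's total count across all clusters. It also tokenizes on whitespace and skips stop-word removal. The function names are hypothetical.

```python
import math
from collections import Counter
from typing import Dict, List


def class_tfidf(docs_per_class: Dict[int, str]) -> Dict[int, Dict[str, float]]:
    """Score every word of every cluster with a simplified c-TF-IDF weight."""
    if not docs_per_class:
        return {}
    # word counts per cluster (one concatenated document per cluster)
    counts = {label: Counter(text.lower().split())
              for label, text in docs_per_class.items()}
    # total count of each word across all clusters
    total_per_term: Counter = Counter()
    for cluster_counts in counts.values():
        total_per_term.update(cluster_counts)
    avg_words = sum(total_per_term.values()) / len(counts)

    scores: Dict[int, Dict[str, float]] = {}
    for label, cluster_counts in counts.items():
        n_words = sum(cluster_counts.values())
        scores[label] = {
            term: (freq / n_words) * math.log(1 + avg_words / total_per_term[term])
            for term, freq in cluster_counts.items()
        }
    return scores


def top_keywords(scores: Dict[int, Dict[str, float]], k: int = 3) -> Dict[int, List[str]]:
    """Return the k highest-scoring words per cluster, best first."""
    return {
        label: [term for term, _ in
                sorted(word_scores.items(), key=lambda kv: (-kv[1], kv[0]))[:k]]
        for label, word_scores in scores.items()
    }


toy_clusters = {0: "broken chair broken leg", 1: "great chair great price"}
keywords = top_keywords(class_tfidf(toy_clusters), k=2)
```

Words shared by all clusters (here `chair`) are down-weighted, so each cluster is labelled by the words that distinguish it.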

### Processing & Plotting

In [18]:
# get cohere and umap embeddings
embeds, umap_embeds = get_embeddings(df)
# store umap embeddings in dataframe for plotting
df['embeds'] = embeds
df['x'] = umap_embeds[:,0]
df['y'] = umap_embeds[:,1]

In [26]:
# get cluster keywords
df = get_keywords(df, n_clusters=6)

# Plot 2D cluster plot
selection = alt.selection_multi(fields=['keywords'], bind='legend')

# note: the Hacker News url/href from the original template was dropped,
# since this dataframe has no 'id' column to link to
chart = alt.Chart(df).mark_circle(size=60, stroke='#666', strokeWidth=1, opacity=0.3).encode(
    x=alt.X('x',
        scale=alt.Scale(zero=False),
        axis=alt.Axis(labels=False, ticks=False, domain=False)
    ),
    y=alt.Y('y',
        scale=alt.Scale(zero=False),
        axis=alt.Axis(labels=False, ticks=False, domain=False)
    ),
    color=alt.Color('keywords:N',
                    legend=alt.Legend(columns=1, symbolLimit=0, labelFontSize=14)
                   ),
    opacity=alt.condition(selection, alt.value(1), alt.value(0.2)),
    tooltip=['title', 'keywords', 'cluster']
).properties(
    width=800,
    height=500
).add_selection(
    selection
).configure_legend(labelLimit=0).configure_view(
    strokeWidth=0
).configure(background="#FDF7F0").properties(
    title=title
)
chart.interactive()

## Semantic Search

In [20]:
# resource: https://docs.cohere.ai/semantic-search

In [21]:
query = 'which products are broken?'

### Helper Functions

In [22]:
def search(df: pd.DataFrame, query: str, n_relevantDocs: int=20) -> pd.DataFrame:
  '''
  inputs:
  - df: dataframe with embeds column
  - query: search query
  - n_relevantDocs: number of documents to return for query

  outputs:
  - df: dataframe with additional column 'relevance' (1 for retrieved documents, 0 otherwise)
  '''

  # embed the query directly; the UMAP projection is not needed for the search itself
  query_embed = co.embed(texts=[query],
                         model=model_size,
                         truncate='LEFT').embeddings

  # create search index
  embeds =  np.array(list(df['embeds'].values))

  search_index = AnnoyIndex(embeds.shape[1], 'angular')
  # Add all the vectors to the search index
  for i in range(len(embeds)):
      search_index.add_item(i, embeds[i])

  search_index.build(10) # 10 trees
  search_index.save('test.ann')

  # Retrieve the nearest neighbors
  similar_item_ids = search_index.get_nns_by_vector(query_embed[0],
                                                    n_relevantDocs,
                                                    include_distances=True)
  # Format the results
  results = pd.DataFrame(data={'texts': df.iloc[similar_item_ids[0]]['text'],
                              'distance': similar_item_ids[1]})

  # flag the retrieved documents for plotting -> relevance is 1 for a hit, 0 otherwise
  # Annoy ids are row positions, so assign by position instead of joining on the index
  # (this also allows the cell to be re-run without a column-overlap error)
  hits = set(similar_item_ids[0])
  df = df.copy()
  df['relevance'] = [1 if k in hits else 0 for k in range(len(df))]

  return df
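
The distances returned by the `'angular'` index can be related to cosine similarity. Annoy documents this metric as the Euclidean distance between normalized vectors, i.e. `sqrt(2 * (1 - cos(u, v)))`. The sketch below reproduces that formula with an exact brute-force search in pure Python; it assumes non-zero vectors, and the function names are hypothetical. For a dataset of this size, such a search gives exact neighbours, whereas Annoy's results are approximate.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def angular_distance(u: Sequence[float], v: Sequence[float]) -> float:
    """Euclidean distance between the normalized vectors: sqrt(2 * (1 - cos))."""
    # max() guards against tiny negative values caused by rounding
    return math.sqrt(max(0.0, 2.0 * (1.0 - cosine_similarity(u, v))))


def brute_force_search(vectors: Sequence[Sequence[float]],
                       query: Sequence[float],
                       n: int = 2) -> List[Tuple[int, float]]:
    """Return (row position, distance) for the n vectors closest to the query."""
    scored = [(i, angular_distance(vec, query)) for i, vec in enumerate(vectors)]
    scored.sort(key=lambda pair: pair[1])
    return scored[:n]


toy_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
nearest = brute_force_search(toy_vectors, query=[1.0, 0.0], n=2)
```

A distance `d` can be converted back to a cosine similarity with `1 - d**2 / 2`, which is easier to interpret than the raw distance.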

### Processing & Plotting

In [23]:
# semantic search
df = search(df, query)

In [24]:
# Plot
# note: the Hacker News url from the original template was dropped,
# since this dataframe has no 'id' column to link to
chart = alt.Chart(df).mark_circle(size=60, stroke='#666', strokeWidth=1, opacity=0.3).encode(
    x=alt.X('x',
        scale=alt.Scale(zero=False),
        axis=alt.Axis(labels=False, ticks=False, domain=False)
    ),
    y=alt.Y('y',
        scale=alt.Scale(zero=False),
        axis=alt.Axis(labels=False, ticks=False, domain=False)
    ),
    color=alt.Color('relevance', scale=alt.Scale(domain=[0, 1], range=['blue', 'red'])),
    tooltip=['title']
).properties(
    width=800,
    height=500
).configure_legend(labelLimit=0).configure_view(
    strokeWidth=0
).configure(background="#FDF7F0")

chart.interactive()