<img src="https://relevance.ai/wp-content/uploads/2021/11/logo.79f303e-1.svg" width="150" alt="Relevance AI" />
<h5> Developer-first vector platform for ML teams </h5>

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RelevanceAI/workflows/blob/main/workflows/vector-rake/vector_rake.ipynb)

# Installation ⚡

In [None]:
%%capture
!pip install RelevanceAI==2.1.4
!pip install rake-nltk
!pip install -U vectorhub[clip]

# Set up 🛠

## Variables 📑

In [None]:
# Tuning knobs for VectorRake.vector_rake (passed to it in the "Label" cell below).
min_ngram_cnt = 0  # minimum number of words in a selected key-phrase (0 keeps every RAKE phrase)
top_n_single = 2   # number of key phrases to keep from each entry: those closest to the entry's own vector
top_n_all = 5      # number of key phrases to assign to each entry from the overall set of keyphrases
key_word_cnt = 50  # maximum number of key phrases to select from the whole dataset (the shared label set)

## Client 🤖

In [None]:
from relevanceai import Client

# Relevance AI API client. NOTE(review): `client` is not referenced by any later
# cell visible here; presumably kept for uploading results / authentication — confirm.
client = Client()


## Encoder 🦾

In [None]:
from vectorhub.encoders.text.sentence_transformers import SentenceTransformer2Vec

# Encoder used to embed candidate key phrases. Phrase vectors are compared with the
# documents' `vector_field` vectors, so this should be the model that produced them
# (presumably CLIP ViT-B/32 for `product_title_clip_vector_` — confirm).
text_model = SentenceTransformer2Vec('clip-ViT-B-32')

# VectorRake class 📚

In [None]:
import nltk
from nltk import pos_tag
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('averaged_perceptron_tagger')
stopwords = nltk.corpus.stopwords.words('english')

import string
import random
import numpy as np
from sklearn.metrics.pairwise import cosine_distances
from rake_nltk import Rake
from collections import Counter
from typing import List


class VectorRake:
    """Label documents with key phrases chosen by RAKE and ranked in vector space.

    Pipeline (see ``vector_rake``):
      1. RAKE extracts candidate phrases from each document's text.
      2. Each document keeps the ``top_n_single`` phrases whose embeddings are
         closest (cosine distance) to the document's own vector.
      3. The pooled phrases are reduced to at most ``n`` shared labels
         (``process_key_words_naive``).
      4. Every document with a vector is tagged with its ``top_n_all`` closest
         shared labels.
    """

    def __init__(self):
        self.rake_nltk_var = Rake()

    def rake_per_text(self, text: str, min_ngram_cnt: int) -> List[str]:
        """Return RAKE-ranked phrases extracted from ``text``.

        Phrases with fewer than ``min_ngram_cnt`` space-separated words, or with
        two characters or fewer, are dropped. Order follows
        ``Rake.get_ranked_phrases()``.
        """
        self.rake_nltk_var.extract_keywords_from_text(text)
        extracted_keywords = self.rake_nltk_var.get_ranked_phrases()
        return [
            kw for kw in extracted_keywords
            if len(kw.split(' ')) >= min_ngram_cnt and len(kw) > 2
        ]

    def min_cosine_distances_text_to_keywords(self, text_vec: List, keywrds_vecs: List, top_n: int = 2):
        """Return indices of the ``top_n`` keyword vectors closest to ``text_vec``.

        ``text_vec`` is expected to hold a single row (the caller passes
        ``[doc_vector]``). The result is a NumPy index array into
        ``keywrds_vecs`` ordered by ascending cosine distance.
        """
        cosine_dist = cosine_distances(keywrds_vecs, text_vec).reshape(1, -1)
        return np.argsort(cosine_dist)[0][:top_n]

    def min_cosine_distances_docs_to_keywords(self, docs_vec: List, keywrds_vecs: List, top_n: int = 5):
        """For each row of ``docs_vec``, return indices of its ``top_n`` closest keyword vectors.

        Returns a 2-D NumPy index array with one row per document, each row
        ordered by ascending cosine distance.
        """
        cosine_dist = cosine_distances(docs_vec, keywrds_vecs)
        return np.argsort(cosine_dist, axis=1)[:, :top_n]

    def extract_if_puntuation(self, text_list: List) -> set:
        """Return the set of entries in ``text_list`` containing any ``string.punctuation`` char.

        ``text_list`` may be any iterable of strings; a dict contributes its keys.
        """
        punctuation = set(string.punctuation)
        return {text for text in text_list if any(ch in punctuation for ch in text)}

    def extract_if_dgits(self, text_list: List) -> set:
        """Return the set of entries in ``text_list`` containing any ASCII digit.

        ``text_list`` may be any iterable of strings; a dict contributes its keys.
        """
        digits = set(string.digits)
        return {text for text in text_list if any(ch in digits for ch in text)}

    def get_substrings(self, text_list: dict) -> set:
        """Return phrases contained in another phrase, e.g. "service" in "good service".

        ``text_list`` maps phrase -> occurrence count (as built by
        ``process_key_words_naive``); only phrases seen exactly once are
        compared. Containment is a plain substring test.
        """
        seen_once = [kw for kw, cnt in text_list.items() if cnt == 1]
        overlap = set()
        # Compare every unordered pair exactly once, so an overlap is found no
        # matter where the two phrases sit in the list.
        for i, h1 in enumerate(seen_once):
            for h2 in seen_once[i + 1:]:
                shorter, longer = (h2, h1) if len(h2) < len(h1) else (h1, h2)
                if shorter in longer:
                    overlap.add(shorter)
        return overlap

    def started_with_adjective(self, text_list: List) -> List[str]:
        """Return entries whose first token is POS-tagged as an adjective (JJ, JJR, JJS).

        The whole phrase is tagged so the first token is tagged in context.
        Empty/whitespace-only entries are skipped.
        """
        adjective_tags = {'JJ', 'JJR', 'JJS'}
        adj_list = []
        for kw in text_list:
            tokens = kw.split()
            if tokens and pos_tag(tokens)[0][1] in adjective_tags:
                adj_list.append(kw)
        return adj_list

    def remove_plurals(self, text_list: List):
        """Placeholder: returns ``text_list`` unchanged.

        TODO: drop plural forms that duplicate a singular phrase.
        """
        return text_list

    def process_key_words_naive(self, all_keywords: List, n: int) -> List[str]:
        """Reduce the pooled keywords to at most ``n`` labels with simple heuristics.

        Candidates are taken in this priority order:
          1. phrases chosen for more than one document;
          2. phrases contained in another phrase (see ``get_substrings``);
          3. remaining phrases starting with an adjective;
          4. a random sample of whatever is left, to top up towards ``n``.
        Phrases containing punctuation or digits are excluded at every stage.
        If steps 1-3 already yield ``n`` or more phrases, ``n`` of them are
        sampled at random. The result uses ``random`` and its order is not
        meaningful.
        """
        all_kwrds_cnt_dict = dict(Counter(all_keywords))

        # No punctuation and no digits, regardless of how a phrase qualifies.
        to_skip = (self.extract_if_puntuation(all_kwrds_cnt_dict)
                   | self.extract_if_dgits(all_kwrds_cnt_dict))

        # Repetition: selected as a top phrase by several documents.
        selected_kwrds = [kw for kw, cnt in all_kwrds_cnt_dict.items()
                          if cnt > 1 and kw not in to_skip]

        # Word overlap: phrases that appear inside other phrases.
        overlap = self.get_substrings(all_kwrds_cnt_dict) - to_skip
        selected_kwrds.extend(sorted(overlap))

        left_unused = set(all_kwrds_cnt_dict) - set(selected_kwrds) - to_skip

        # Adjective-led phrases next (sorted so iteration order is reproducible).
        selected_kwrds.extend(self.started_with_adjective(sorted(left_unused)))
        left_unused -= set(selected_kwrds)

        if len(selected_kwrds) < n:
            # Top up with a random sample of the rest; when fewer remain than are
            # needed, take them all. random.sample requires a sequence (sets are
            # rejected since Python 3.11), hence the sorted list.
            pool = sorted(left_unused)
            needed = n - len(selected_kwrds)
            selected_kwrds.extend(random.sample(pool, min(needed, len(pool))))
        else:
            selected_kwrds = random.sample(selected_kwrds, n)

        return selected_kwrds

    def vector_rake(self, docs: List, text_model, text_field: str,
                    vector_field: str, min_ngram_cnt: int = 0,
                    top_n_single: int = 2,
                    top_n_all: int = 5,
                    n: int = 50):
        """Tag ``docs`` in place with key phrases selected by RAKE + cosine distance.

        Args:
            docs: list of dict documents.
            text_model: object with ``encode(str)`` returning a vector (list or
                NumPy array). It must embed into the same space as
                ``d[vector_field]`` — presumably the model that produced that
                field; confirm for your data.
            text_field: key holding each document's text.
            vector_field: key holding each document's vector.
            min_ngram_cnt: minimum number of words per RAKE phrase.
            top_n_single: phrases kept per document before pooling.
            top_n_all: labels assigned to each document.
            n: maximum size of the shared label set.

        Returns:
            The same ``docs`` list. Every document containing ``vector_field``
            gets ``d["_label_"][text_field + "_vector_rake"]`` set to a list of
            up to ``top_n_all`` labels (empty when no labels could be derived).
        """
        def as_list(vec):
            # Encoders may return NumPy arrays; store plain lists in the label info.
            return vec if isinstance(vec, list) else vec.tolist()

        all_keywords = []       # per-document top phrases, pooled (repeats kept for counting)
        all_keywords_vecs = {}  # phrase -> embedding; each distinct phrase is encoded once

        for d in docs:
            # Both the text (to extract phrases) and the vector (to rank them) are needed.
            if d.get(text_field) is None or vector_field not in d:
                continue
            doc_keywords = self.rake_per_text(text=d[text_field], min_ngram_cnt=min_ngram_cnt)
            if not doc_keywords:
                continue
            for kw in doc_keywords:
                if kw not in all_keywords_vecs:
                    all_keywords_vecs[kw] = text_model.encode(kw)
            doc_keywrds_vecs = [all_keywords_vecs[kw] for kw in doc_keywords]
            min_dist_idxs = self.min_cosine_distances_text_to_keywords(
                text_vec=[d[vector_field]],
                keywrds_vecs=doc_keywrds_vecs,
                top_n=top_n_single)
            all_keywords.extend(doc_keywords[idx] for idx in min_dist_idxs)

        keywords_info = [
            {
                "_id": i,
                "label": w,
                "_label_vector_": as_list(all_keywords_vecs[w]),
            }
            for i, w in enumerate(self.process_key_words_naive(all_keywords, n))
        ]

        vec_docs = [d for d in docs if vector_field in d]
        out_key = text_field + "_vector_rake"

        if keywords_info and vec_docs:
            closest_topn_index = self.min_cosine_distances_docs_to_keywords(
                [d[vector_field] for d in vec_docs],
                [info["_label_vector_"] for info in keywords_info],
                top_n=top_n_all)
            tag_lists = [[keywords_info[ind]["label"] for ind in row]
                         for row in closest_topn_index]
        else:
            # Nothing to compare against (cosine_distances rejects empty input).
            tag_lists = [[] for _ in vec_docs]

        for d, tags in zip(vec_docs, tag_lists):
            d.setdefault("_label_", {})[out_key] = tags

        return docs


# Data

In [None]:
from relevanceai.datasets import get_ecommerce_dataset_encoded

# Sample e-commerce documents; presumably shipped with precomputed vectors such as
# `product_title_clip_vector_` (used below) — confirm against the dataset.
docs = get_ecommerce_dataset_encoded()

text_field = "product_title"  # field RAKE extracts candidate phrases from
vector_field='product_title_clip_vector_'  # document vector that phrases are ranked against

### Label 🏷

In [None]:
vr = VectorRake()

# Drop the whole `_label_` dict left by any earlier run, so re-running this
# cell starts from unlabelled documents.
for d in docs:
    d.pop('_label_', None)

docs = vr.vector_rake(
    docs,
    text_model=text_model,
    text_field=text_field,
    vector_field=vector_field,
    n=key_word_cnt,
    min_ngram_cnt=min_ngram_cnt,
    top_n_single=top_n_single,
    top_n_all=top_n_all,
)

# Sample 🟦

In [None]:
# Show one labelled document: its text and the vector-RAKE tags assigned to it.
# Clamp the index so the cell also works on datasets with fewer than 151 docs.
d = docs[min(150, len(docs) - 1)]
print(d[text_field])
# Read the same key vector_rake() writes (text_field + "_vector_rake"), so this
# cell keeps working if text_field is changed; prints None if the doc got no tags.
print(d.get('_label_', {}).get(text_field + "_vector_rake"))