# Fine-Tuning a Reranker using Cohere

### This demo will show you how to:
1. Generate synthetic data using DSPy
2. Export all your data from your Weaviate instance
3. Fine-tune a reranker using Cohere
4. Query in Weaviate using your fine-tuned reranker model

#### Note: To fine-tune a model in Cohere, you need to have a minimum of 256 unique queries with at least 1 relevant passage per query. If you already have a dataset with query + relevant passages, you can skip to the end of the notebook!

#### For a full walkthrough of the demo, check out the [accompanying blog post](https://weaviate.io/blog/fine-tuning-coheres-reranker)!

## Connect to Weaviate Instance

In [14]:
import weaviate
import json 

client = weaviate.Client(
    url = "WEAVIATE_URL",  # Replace with your cluster url
    auth_client_secret=weaviate.AuthApiKey(api_key="AUTH_KEY"),  # Replace w/ your Weaviate instance API key
    additional_headers = {
        "X-Cohere-Api-Key": "API-KEY" # Replace with your inference API key
    }
)

## Load in Data

In [7]:
import os
import re

def chunk_list(lst, chunk_size):
    """Break a list into chunks of the specified size."""
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def split_into_sentences(text):
    """Split text into sentences using regular expressions."""
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    return [sentence.strip() for sentence in sentences if sentence.strip()]

def read_and_chunk_index_files(main_folder_path):
    """Read index.mdx files from subfolders, split into sentences, and chunk every 5 sentences."""
    blog_chunks = []
    for folder_name in os.listdir(main_folder_path):
        subfolder_path = os.path.join(main_folder_path, folder_name)
        if os.path.isdir(subfolder_path):
            index_file_path = os.path.join(subfolder_path, 'index.mdx')
            if os.path.isfile(index_file_path):
                with open(index_file_path, 'r', encoding='utf-8') as file:
                    content = file.read()
                    sentences = split_into_sentences(content)
                    sentence_chunks = chunk_list(sentences, 5)
                    sentence_chunks = [' '.join(chunk) for chunk in sentence_chunks]
                    blog_chunks.extend(sentence_chunks)
    return blog_chunks

# Example usage
main_folder_path = './blog'
blogs = read_and_chunk_index_files(main_folder_path)
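
To make the chunking behavior concrete, here is a small self-contained sketch that runs the same two helpers on a made-up string. The sample text and the chunk size of 3 are assumptions chosen only to keep the example short; the notebook itself chunks every 5 sentences.

```python
import re

def chunk_list(lst, chunk_size):
    """Break a list into chunks of the specified size."""
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def split_into_sentences(text):
    """Split text into sentences using the same regular expression as above."""
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    return [sentence.strip() for sentence in sentences if sentence.strip()]

# Hypothetical sample text, not taken from the blog folder
sample_text = (
    "Weaviate is a vector database. It supports hybrid search. "
    "Is it open source? Yes."
)

sentences = split_into_sentences(sample_text)
# Group the sentences three at a time and join each group back into one string
chunks = [" ".join(chunk) for chunk in chunk_list(sentences, 3)]

for chunk in chunks:
    print(chunk)
```

The last chunk simply holds whatever sentences are left over, so chunks at the end of a blog post can be much shorter than the others.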

## Define Schema

If you need to reset your schema and delete objects in a collection, run:
`client.schema.delete_all()` or `client.schema.delete_class("Blogs")`

In [28]:
schema = {
   "classes": [
       {
           "class": "Blogs",
           "description": "Weaviate blogs",
           "vectorizer": "text2vec-cohere",
           "properties": [
               {
                   "name": "content",
                   "dataType": ["text"],
                   "description": "Content from the blogs.",
               },
               {
                   "name": "synthetic_query",
                   "dataType": ["text"],
                   "description": "Synthetic query generated from a LM."
               }
           ]
       }      
   ]
}
    
client.schema.create(schema)

## Import 

In [None]:
%pip install dspy-ai > /dev/null

#### To generate the synthetic queries, we will use a DSPy signature and the chain-of-thought module.

In [None]:
import dspy
import cohere
cohere_key = "API-KEY"

class WriteQuery(dspy.Signature):
    """Write a query that this document would have the answer to."""

    document = dspy.InputField(desc="A document containing information.") 
    query = dspy.OutputField(desc="A short question uniquely answered by the document.")

command_nightly = dspy.Cohere(model="command-nightly",max_tokens=1000, api_key=cohere_key)

for blog_chunk in blogs:
    with dspy.context(lm=command_nightly):
        llm_query = dspy.ChainOfThought(WriteQuery)(document=blog_chunk)
    print(llm_query)
    data_properties = {
        "content": blog_chunk,
        "synthetic_query": llm_query.query
    }
    print(f"{data_properties}\n")
    client.data_object.create(data_properties, "Blogs")

#### Here is one example of the chain-of-thought module in DSPy. It takes the initial signature (prompt) and inserts the first blog chunk into it.

In [35]:
with dspy.context(lm=command_nightly):
    dspy.ChainOfThought(WriteQuery)(document=blogs[0]).query
command_nightly.inspect_history(n=1)





Write a query that this document would have the answer to.

---

Follow the following format.

Document: A document containing information.
Reasoning: Let's think step by step in order to ${produce the query}. We ...
Query: A short question uniquely answered by the document.

---

Document: --- title: Combining LangChain and Weaviate slug: combining-langchain-and-weaviate authors: [erika] date: 2023-02-21 tags: ['integrations'] image: ./img/hero.png description: "LangChain is one of the most exciting new tools in AI. It helps overcome many limitations of LLMs, such as hallucination and limited input lengths." --- ![Combining LangChain and Weaviate](./img/hero.png) Large Language Models (LLMs) have revolutionized the way we interact and communicate with computers. These machines can understand and generate human-like language on a massive scale. LLMs are a versatile tool that is seen in many applications like chatbots, content creation, and much more. Despite being a powerful tool, 

## Export the data from your Weaviate instance

To fine-tune the model, we need to export our data and upload it to Cohere.

In [18]:
'''
This example will show you how to get all of your data
out of Weaviate and into a JSON file using the Cursor API!
'''
import time
start = time.time()

# Step 1 - Get the UUID of the first object inserted into Weaviate

get_first_object_weaviate_query = """
{
  Get {
    Blogs {
      _additional {
        id
      }
    }
  }
}
"""

results = client.query.raw(get_first_object_weaviate_query)
uuid_cursor = results["data"]["Get"]["Blogs"][0]["_additional"]["id"]

# Step 2 - Get the Total Objects in Weaviate

total_objs_query = """
{
    Aggregate {
        Blogs {
            meta {
                count
            }
        }
    }
}
"""

results = client.query.raw(total_objs_query)
total_objects = results["data"]["Aggregate"]["Blogs"][0]["meta"]["count"]

# Step 3 - Iterate through Weaviate with the Cursor
increment = 50
data = []
for i in range(0, total_objects, increment):
    query = (
        client.query.get("Blogs", ["content", "synthetic_query"])
        .with_additional(["id"])
        .with_limit(increment)
    )
    # `after` excludes the cursor object itself, so the first batch is
    # fetched without a cursor to avoid skipping the first object
    if i > 0:
        query = query.with_after(uuid_cursor)
    results = query.do()["data"]["Get"]["Blogs"]
    if not results:
        break
    # extract data from result into JSON
    for result in results:
        # skip objects whose synthetic query is too short to be useful
        if len(result["synthetic_query"]) < 5:
            continue
        data.append({
            "query": result["synthetic_query"],
            "relevant_passages": [result["content"]],
        })
    # update uuid cursor to the last object of this batch to continue the loop
    uuid_cursor = results[-1]["_additional"]["id"]

# save JSON
file_path = "data.jsonl"
with open(file_path, 'w') as jsonl_file:
    for item in data:
        jsonl_file.write(json.dumps(item) + '\n')

print("Your data is out of Weaviate!")
print(f"Extracted {total_objects} in {time.time() - start} seconds.")

Your data is out of Weaviate!
Extracted 405 in 0.6107177734375 seconds.


In [21]:
data = []
with open("data.jsonl", "r") as file:
    for line in file:
        data.append(json.loads(line))

split_index = int(len(data)*0.8)
train_data = data[:split_index]
validation_data = data[split_index:]

with open("./train.jsonl", "w") as train_file:
    for line in train_data:
        train_file.write(json.dumps(line) + "\n")

with open("./validation.jsonl", "w") as validation_file:
    for line in validation_data:
        validation_file.write(json.dumps(line) + "\n")
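
Before uploading, it can help to check a file against the requirement from the note at the top of this notebook (at least 256 unique queries, each with at least 1 relevant passage). The helper below is a minimal sketch, not part of the Cohere or Weaviate APIs: the name `check_rerank_dataset` and the shape of its return value are assumptions made for illustration, and it only assumes the `query` / `relevant_passages` record layout written by the export cell.

```python
import json

MIN_UNIQUE_QUERIES = 256  # minimum stated in the note at the top of this notebook

def check_rerank_dataset(path, min_unique_queries=MIN_UNIQUE_QUERIES):
    """Return (unique_query_count, problems) for a JSONL file of rerank records."""
    unique_queries = set()
    problems = []
    with open(path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            record = json.loads(line)
            query = (record.get("query") or "").strip()
            passages = record.get("relevant_passages") or []
            if not query:
                problems.append(f"line {line_number}: empty or missing query")
                continue
            if not any(passage.strip() for passage in passages):
                problems.append(f"line {line_number}: no relevant passage")
                continue
            unique_queries.add(query)
    if len(unique_queries) < min_unique_queries:
        problems.append(
            f"only {len(unique_queries)} unique queries, need at least {min_unique_queries}"
        )
    return len(unique_queries), problems
```

For example, `count, problems = check_rerank_dataset("./train.jsonl")` returns the number of usable unique queries and a list of human-readable issues; an empty list means the file passed these checks.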


## Fine-Tune on Cohere

### You will need to upload your training and validation datasets to [Cohere](https://dashboard.cohere.com/fine-tuning/create?endpoint=rerank). Once you start the training, expect it to take a few hours before it finishes and outputs your `model_id`. You will also have access to a dashboard that shows the model's performance (screenshot below).

![Cohere Dashboard](cohere-dashboard.png)

## Re-Index Data 

To use our fine-tuned reranker, we need to upload our data again to a new collection whose `reranker-cohere` module config is set to the `model_id`.

In [16]:
schema = {
   "classes": [
       {
           "class": "BlogsFineTuned",
           "description": "Weaviate blogs",
           "vectorizer": "text2vec-cohere",
           "moduleConfig": {
                "reranker-cohere": {
                    "model": "model_id" # grab the model_id from Cohere
                }
           },
           "properties": [
               {
                   "name": "content",
                   "dataType": ["text"],
                   "description": "Content from the blogs.",
               }
           ]
       }      
   ]
}
    
client.schema.create(schema)

### Upload data (same as above)

In [27]:
for blog in blogs:
    data_properties = {
        "content": blog
    }
    client.data_object.create(
        data_object = data_properties,
        class_name = "BlogsFineTuned"
    )

## Query Time

#### Query without Reranking

In [15]:
response = (
    client.query
    .get("BlogsFineTuned", ["content"])
    .with_near_text({
        "concepts": ["Ref2Vec in Weaviate"]
    })
    .with_limit(5)
    .do()
)

print(json.dumps(response, indent=2))

{
  "data": {
    "Get": {
      "BlogsFineTuned": [
        {
          "content": "---\ntitle: What is Ref2Vec and why you need it for your recommendation system\nslug: ref2vec-centroid\nauthors: [connor]\ndate: 2022-11-23\ntags: ['integrations', 'concepts']\nimage: ./img/hero.png\ndescription: \"Weaviate introduces Ref2Vec, a new module that utilises Cross-References for Recommendation!\"\n---\n![Ref2vec-centroid](./img/hero.png)\n\n<!-- truncate -->\n\nWeaviate 1.16 introduced the [Ref2Vec](/developers/weaviate/modules/retriever-vectorizer-modules/ref2vec-centroid) module. In this article, we give you an overview of what Ref2Vec is and some examples in which it can add value such as recommendations or representing long objects. ## What is Ref2Vec? The name Ref2Vec is short for reference-to-vector, and it offers the ability to vectorize a data object with its cross-references to other objects. The Ref2Vec module currently holds the name ref2vec-**centroid** because it uses the avera

#### Query with Reranking

In [16]:
response = (
    client.query
    .get("BlogsFineTuned", ["content"])
    .with_near_text({
        "concepts": ["Ref2Vec in Weaviate"]
    })
    .with_additional("rerank(property: \"content\" query: \"Represent users based on their product interactions\") { score }")
    .with_limit(5)
    .do()
)

print(json.dumps(response, indent=2))

{
  "data": {
    "Get": {
      "BlogsFineTuned": [
        {
          "_additional": {
            "rerank": [
              {
                "score": 0.9261703
              }
            ]
          },
          "content": "![Cross-reference](./img/Weaviate-Ref2Vec_1.png)\n\nRef2Vec gives Weaviate another way to vectorize a class, such as the User class, based on their relationships to other classes. This allows Weaviate to quickly create up-to-date representations of users based on their relationships such as recent interactions. If a user clicks on 3 shoe images on an e-commerce store, it is a safe bet that they want to see more shoes. Ref2Vec captures this intuition by calculating vectors that aggregate each User's interaction with another class. The below animation visualizes a real example of this in e-Commerce images."
        },
        {
          "_additional": {
            "rerank": [
              {
                "score": 0.34444344
              }
            ]
   