In [None]:
#Part 2: Embedding Search (with CLIP)
#Created by: Eric Martinez
#For: CSCI 4341
#At: University of Texas Rio-Grande Valley
#Install the required dependencies.

In [None]:
%pip install -q -r requirements.txt

In [None]:
%pip install git+https://github.com/openai/CLIP.git

In [None]:
#Step 1: Create Helpful Functions for Working with CLIP

In [None]:
import torch
import clip

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load hands back the model together with the image transform that the
# rest of the notebook applies before encoding images.
model, preprocess = clip.load("ViT-B/32", device=device)

def encode_text(text):
    """Embed a single string with the CLIP text encoder.

    Args:
        text: The string to embed.

    Returns:
        The text embedding converted to a plain Python list.
    """
    # clip.tokenize raises RuntimeError when the input is longer than the
    # model's context length. truncate=True clips over-long inputs instead,
    # so long documents (e.g. a full restaurant description) still embed.
    text_tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        # A batch of one string was encoded; keep its single row.
        text_embedding = model.encode_text(text_tokens)[0]
    return text_embedding.tolist()

def encode_image(image):
    """Embed one image with the CLIP image encoder.

    Args:
        image: An image accepted by the ``preprocess`` transform (the
            notebook passes PIL images).

    Returns:
        The image embedding converted to a plain Python list.
    """
    # preprocess yields one tensor per image; add a leading batch axis of 1.
    batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        features = model.encode_image(batch)

    # Only one image was encoded, so return the first (and only) row.
    return features[0].tolist()

In [None]:
#Let's try out the text embeddings

In [None]:
# Sanity check: embed a short caption with the CLIP text encoder.
text = "a cute puppy"
text_embedding = encode_text(text)

# Print the number of values in the returned embedding list.
print(len(text_embedding))

In [None]:
#Now image embedding

In [None]:
from PIL import Image
import requests

url = "https://media.istockphoto.com/id/157431311/photo/turkey-sandwich.jpg?s=612x612&w=0&k=20&c=uB6byErFAnWxFkkAqMiGNRJGE8r3nqsSDdqrfBE8HOA="

# Bound the request time and raise on 4xx/5xx responses, so a dead link fails
# here with a clear HTTP error instead of hanging or handing PIL an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
raw_image = Image.open(response.raw)

# Sanity check: embed the image and print the number of values in the embedding.
image_embedding = encode_image(raw_image)
print(len(image_embedding))

In [None]:
# Show the image that was just embedded in the cell output.
display(raw_image)

In [None]:
#Step 2: Create Custom Chroma Embedding Function for CLIP

In [None]:
from chromadb import Documents, EmbeddingFunction, Embeddings

class CLIPEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by the notebook's CLIP text encoder."""

    def __call__(self, texts: Documents) -> Embeddings:
        """Return one ``encode_text`` embedding per document, in input order."""
        return [encode_text(document) for document in texts]

In [None]:
#Step 3: Create new Chroma Collection

In [None]:
from dotenv import load_dotenv
load_dotenv()  # take environment variables from .env.
import os

import chromadb
from chromadb.utils import embedding_functions


def get_chroma_collection(collection_name):
    """Open the named Chroma collection, creating it if it does not exist.

    The collection is stored on disk in the current directory and embeds
    documents with ``CLIPEmbeddingFunction``.

    Args:
        collection_name: Name of the collection to get or create.

    Returns:
        The collection returned by ``get_or_create_collection``.
    """
    ## Use this one to save to memory
    # chroma_client = chromadb.Client()

    ## Use this one to save to disk
    chroma_client = chromadb.PersistentClient(path=".")

    # CLIPEmbeddingFunction defines no __init__, so it takes no constructor
    # arguments (passing device=... raises TypeError). The device is already
    # applied inside encode_text through the module-level `device`.
    clip_ef = CLIPEmbeddingFunction()

    return chroma_client.get_or_create_collection(
        name=collection_name, embedding_function=clip_ef
    )

In [None]:
# Open (or create on first run) the "food_clip" collection; the query helpers
# defined later read this module-level `collection`.
collection = get_chroma_collection("food_clip")

In [None]:
#Step 4: Add Data to Chroma Collection

In [None]:
import json

def load_data(path="data.json"):
    """Load the restaurant records from a JSON file.

    Args:
        path: Location of the JSON file. Defaults to ``data.json`` in the
            current working directory.

    Returns:
        The parsed JSON content (the notebook iterates over it as a
        sequence of restaurant dicts).
    """
    # Be explicit about UTF-8 so non-ASCII text decodes the same way no matter
    # what the platform's default encoding is.
    with open(path, encoding="utf-8") as f:
        return json.load(f)

In [None]:
def add_data_to_collection(data, collection):
    """Build one document per restaurant and add them all to the collection.

    Args:
        data: Iterable of restaurant dicts. Each must contain the keys
            ``name``, ``address``, ``rating`` and ``description of restaurant``.
        collection: Object exposing ``add(documents=..., metadatas=..., ids=...)``
            (the notebook passes the Chroma collection).

    Raises:
        KeyError: If a restaurant is missing one of the required keys.
    """
    documents = []
    metadatas = []
    ids = []

    for i, restaurant in enumerate(data):
        name = restaurant['name']
        address = restaurant['address']
        rating = restaurant['rating']
        # The dict key contains spaces, but the variable needs a valid Python
        # identifier (the original `description of restaurant = ...` was a
        # SyntaxError).
        description_of_restaurant = restaurant['description of restaurant']

        # This string is what gets embedded for the restaurant; extend it if
        # more fields should influence search.
        embeddable_string = f"{name} {address} {rating} {description_of_restaurant}"
        documents.append(embeddable_string)

        # Store the whole record as metadata so queries can return it as-is.
        metadatas.append(restaurant)

        # Use the position in `data` as the id.
        ids.append(str(i))

    collection.add(
        documents=documents,
        metadatas=metadatas,
        ids=ids
    )

In [None]:
# Read the restaurant records from data.json and add them to the collection.
data = load_data()
add_data_to_collection(data, collection)

In [None]:
#Step 4: Query the Collection
#Define helper functions for querying the collection by text or by image

In [None]:
def get_results_by_text(query, n_results=2):
    """Return the metadata of the restaurants that best match a text query.

    Args:
        query: Free-text search string.
        n_results: Maximum number of matches to return. Values below 1
            yield an empty list.

    Returns:
        A list of metadata dicts in the order the collection returned them.
        It may be shorter than ``n_results``.
    """
    # UI sliders may deliver the count as a float; the query needs an int.
    n_results = int(n_results)
    if n_results <= 0:
        return []

    # Honour the caller's n_results (it used to be overwritten with 2).
    results = collection.query(query_texts=[query], n_results=n_results)

    # One query was sent, so the matches are in the first (only) inner list.
    # Return whatever came back rather than indexing range(n_results), which
    # raises IndexError when fewer matches are available.
    return list(results["metadatas"][0])

def get_results_by_image(image, n_results=2):
    """Return the metadata of the restaurants that best match an image.

    The image is embedded with ``encode_image`` and that embedding is used
    to query the collection.

    Args:
        image: Image to search with (anything ``encode_image`` accepts).
        n_results: Maximum number of matches to return. Values below 1
            yield an empty list.

    Returns:
        A list of metadata dicts in the order the collection returned them.
        It may be shorter than ``n_results``.
    """
    # UI sliders may deliver the count as a float; the query needs an int.
    n_results = int(n_results)
    if n_results <= 0:
        return []

    image_embedding = encode_image(image)

    # Honour the caller's n_results (it used to be overwritten with 2).
    results = collection.query(query_embeddings=[image_embedding], n_results=n_results)

    # One query was sent, so the matches are in the first (only) inner list.
    # Return whatever came back rather than indexing range(n_results), which
    # raises IndexError when fewer matches are available.
    return list(results["metadatas"][0])

In [None]:
#test out querying by text

In [None]:
# Run a sample text query and print the metadata of each match.
results = get_results_by_text("fajita", n_results=2)

for result in results:
    print(result)

In [None]:
#test out querying by image

In [None]:
# sandwich image
url = "https://media.istockphoto.com/id/157431311/photo/turkey-sandwich.jpg?s=612x612&w=0&k=20&c=uB6byErFAnWxFkkAqMiGNRJGE8r3nqsSDdqrfBE8HOA="

# Bound the request time and raise on 4xx/5xx responses, so a dead link fails
# here with a clear HTTP error instead of hanging or handing PIL an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
raw_image = Image.open(response.raw)

# Query the collection with the image and print the metadata of each match.
results = get_results_by_image(raw_image, n_results=2)

for result in results:
    print(result)

In [None]:
#Step 5: Build the Gradio UI

In [None]:
from dotenv import load_dotenv
load_dotenv()  # take environment variables from .env.
import gradio as gr
import openai
import pandas as pd
import gradio as gr

def search_by_text(query, n_results):
    """Gradio handler: run a text search and return the matches as a table.

    Args:
        query: Free-text search string from the UI.
        n_results: Number of results requested in the UI.

    Returns:
        A pandas DataFrame with the columns ``name``, ``address``, ``rating``
        and ``description of restaurant``.

    Raises:
        gr.Error: If the query or the table construction fails; the message
            is the text of the underlying exception.
    """
    try:
        # Keep the query inside the try so lookup failures are also reported
        # as a gr.Error rather than propagating as a raw exception.
        results = get_results_by_text(query, n_results=n_results)
        return pd.DataFrame(results, columns=['name', 'address', 'rating', 'description of restaurant'])
    except Exception as e:
        # Python 3 exceptions have no `.message` attribute; str(e) gives the
        # message without raising AttributeError and hiding the real error.
        raise gr.Error(str(e)) from e
        
        
def search_by_image(image, n_results):
    """Gradio handler: run an image search and return the matches as a table.

    Args:
        image: Image from the UI, or None when nothing was uploaded.
        n_results: Number of results requested in the UI.

    Returns:
        A pandas DataFrame with the columns ``name``, ``address``, ``rating``
        and ``description of restaurant``.

    Raises:
        gr.Error: If no image was provided, or if the query or the table
            construction fails (the message is the underlying exception text).
    """
    # Submitting without an upload passes None; report that clearly instead
    # of failing somewhere inside the image encoder.
    if image is None:
        raise gr.Error("Please upload an image before searching.")

    try:
        # Keep the query inside the try so lookup failures are also reported
        # as a gr.Error rather than propagating as a raw exception.
        results = get_results_by_image(image, n_results=n_results)
        return pd.DataFrame(results, columns=['name', 'address', 'rating', 'description of restaurant'])
    except Exception as e:
        # Python 3 exceptions have no `.message` attribute; str(e) gives the
        # message without raising AttributeError and hiding the real error.
        raise gr.Error(str(e)) from e

# Column headers shared by both result tables; they match the columns of the
# DataFrames returned by search_by_text and search_by_image.
RESULT_COLUMNS = ['name', 'address', 'rating', 'description of restaurant']

with gr.Blocks() as demo:
    with gr.Tab("Search by Text"):
        with gr.Row():
            with gr.Column():
                query = gr.Textbox(label="What are you looking for?", lines=5)
                # minimum=1: a query for 0 results is rejected by the collection.
                text_n_results = gr.Slider(label="Results to Display", minimum=1, maximum=10, value=2, step=1)
                text_btn = gr.Button(value ="Submit")
                text_table = gr.Dataframe(label="Results", headers=RESULT_COLUMNS)
            text_btn.click(search_by_text, inputs = [query, text_n_results], outputs = [text_table])

    with gr.Tab("Search by Image"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Upload a picture", type='pil')
                # minimum=1: a query for 0 results is rejected by the collection.
                image_n_results = gr.Slider(label="Results to Display", minimum=1, maximum=10, value=2, step=1)
                image_btn = gr.Button(value ="Submit")
                image_table = gr.Dataframe(label="Results", headers=RESULT_COLUMNS)
            image_btn.click(search_by_image, inputs = [image, image_n_results], outputs = [image_table])

# Launch only after the Blocks context has closed, i.e. once the layout is
# fully defined.
# NOTE(review): share=True asks Gradio for a public share link; drop it if the
# demo should stay local.
demo.launch(share=True)