<img src = "https://github.com/VeryFatBoy/notebooks/blob/main/common/images/img_github_singlestore-jupyter_featured_2.png?raw=true">

<div id="singlestore-header" style="display: flex; background-color: rgba(235, 249, 245, 0.25); padding: 5px;">
    <div id="icon-image" style="width: 90px; height: 90px;">
        <img width="100%" height="100%" src="https://raw.githubusercontent.com/singlestore-labs/spaces-notebooks/master/common/images/header-icons/browser.png" />
    </div>
    <div id="text" style="padding: 5px; margin-left: 10px;">
        <div id="badge" style="display: inline-block; background-color: rgba(0, 0, 0, 0.15); border-radius: 4px; padding: 4px 8px; align-items: center; margin-top: 6px; margin-bottom: -2px; font-size: 80%">SingleStore Notebooks</div>
        <h1 style="font-weight: 500; margin: 8px 0 0 4px;">AI-Powered Personalized Shopping & Reco Engine with OpenAI CLIP</h1>
    </div>
</div>

In [7]:
!pip cache purge

Files removed: 451


In [8]:
!pip install git+https://github.com/openai/CLIP.git --quiet
!pip install torch --quiet

In [12]:
import clip
import numpy as np
import pandas as pd
import requests
import time
import torch
import warnings

from io import BytesIO
from IPython.display import Image, display
from PIL import Image as PILImage
from tqdm import tqdm

warnings.filterwarnings("ignore")

In [13]:
# Load CLIP model and preprocess function
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device = device)

100%|███████████████████████████████████████| 338M/338M [00:04<00:00, 88.0MiB/s]


In [14]:
url = "https://github.com/singlestore-labs/genai-app-example/raw/refs/heads/main/data/products-1.json"

df = pd.read_json(url)

In [16]:
df.head()

Unnamed: 0,id,created_at,title,image,description,price,gender,title_v,description_v,type_id
0,1,2024-05-02 17:40:30,logo print strap sandals,https://cdn-images.farfetch-contents.com/13/41...,"This product is a type of women's sandal, feat...",1080,women,"[-0.026900072,-0.008780068,-0.0051244507,-0.03...","[-0.009948222,-0.0034086716,-0.0057584653,-0.0...",1
1,2,2024-05-02 17:40:30,embroidered midi dress,https://cdn-images.farfetch-contents.com/13/78...,This is a red floral lace maxi dress. It is de...,392,women,"[-0.05438027,-0.02859787,-0.006243564,-0.01661...","[-0.025190884,0.0064714067,-0.009578454,-0.019...",2
2,3,2024-05-02 17:40:30,FendiMania sock style sneakers,https://cdn-images.farfetch-contents.com/13/37...,This product is a pair of high-fashion sneaker...,1245,women,"[-0.0013927056,-0.009008276,-0.001865071,-0.02...","[0.0058779274,-0.007137483,-0.017122088,-0.017...",3
3,4,2024-05-02 17:40:30,top zip wallet,https://cdn-images.farfetch-contents.com/13/74...,"The pictured product is a wallet, designed in ...",308,women,"[-0.018660557,-0.0040859696,-0.00568928,-0.036...","[-0.007549583,0.010181303,-0.010712964,-0.0347...",4
4,5,2024-05-02 17:40:30,slingback 65 pumps,https://cdn-images.farfetch-contents.com/13/66...,This product is a women's high-heeled shoe fea...,1248,women,"[-0.023413643,-0.00639263,-0.0045952327,-0.022...","[-0.011407746,-0.014593455,-0.02141431,-0.0313...",5


In [17]:
df_small = df.head(10)

In [18]:
df_small.drop(columns = ["title_v", "description_v"], inplace = True)

In [19]:
df_small.head()

Unnamed: 0,id,created_at,title,image,description,price,gender,type_id
0,1,2024-05-02 17:40:30,logo print strap sandals,https://cdn-images.farfetch-contents.com/13/41...,"This product is a type of women's sandal, feat...",1080,women,1
1,2,2024-05-02 17:40:30,embroidered midi dress,https://cdn-images.farfetch-contents.com/13/78...,This is a red floral lace maxi dress. It is de...,392,women,2
2,3,2024-05-02 17:40:30,FendiMania sock style sneakers,https://cdn-images.farfetch-contents.com/13/37...,This product is a pair of high-fashion sneaker...,1245,women,3
3,4,2024-05-02 17:40:30,top zip wallet,https://cdn-images.farfetch-contents.com/13/74...,"The pictured product is a wallet, designed in ...",308,women,4
4,5,2024-05-02 17:40:30,slingback 65 pumps,https://cdn-images.farfetch-contents.com/13/66...,This product is a women's high-heeled shoe fea...,1248,women,5


In [20]:
def preprocess_image(image_url, timeout = 30, retries = 3, backoff = 2):
    attempt = 0
    while attempt < retries:
        try:
            response = requests.get(image_url, timeout = timeout)
            response.raise_for_status()

            image_tensor = preprocess(
                PILImage.open(
                    BytesIO(response.content)
                )
            ).unsqueeze(0).to(device)

            return image_tensor
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}. Retrying in {backoff} seconds...")
            time.sleep(backoff)
            attempt += 1
            backoff *= 2
    print(f"Failed to download image from {image_url} after {retries} attempts.")
    return np.nan

In [21]:
def preprocess_text(text, max_words = 70):
    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words])

    text_tensor = clip.tokenize(text).to(device)
    return text_tensor

In [22]:
# Function to generate combined CLIP embedding
def generate_embeddings(row):
    image_tensor = preprocess_image(row["image"])
    if np.isnan(image_tensor).any():
        return np.nan, np.nan, np.nan

    text_tensor = preprocess_text(row["title"])

    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tensor)

        image_features /= image_features.norm(dim = -1, keepdim = True)
        text_features /= text_features.norm(dim = -1, keepdim = True)

        text_vector = text_features.cpu().numpy().flatten()
        image_vector = image_features.cpu().numpy().flatten()
        combined_vector = (text_vector + image_vector) / 2

        return text_vector, image_vector, combined_vector

In [23]:
tqdm.pandas()

df_small[["text_vector", "image_vector", "combined_vector"]] = df_small.progress_apply(
    lambda row: pd.Series(generate_embeddings(row)),
    axis = 1
)

100%|██████████| 10/10 [01:36<00:00, 10.46s/it]

Attempt 1 failed: HTTPSConnectionPool(host='cdn-images.farfetch-contents.com', port=443): Read timed out. (read timeout=30). Retrying in 2 seconds...


100%|██████████| 10/10 [02:19<00:00, 13.99s/it]


In [25]:
df_small.head(10)

Unnamed: 0,id,created_at,title,image,description,price,gender,type_id,text_vector,image_vector,combined_vector
0,1,2024-05-02 17:40:30,logo print strap sandals,https://cdn-images.farfetch-contents.com/13/41...,"This product is a type of women's sandal, feat...",1080,women,1,"[-0.0015065934, -0.090367936, 0.016541064, -0....","[0.014005895, -0.049053773, 0.034701098, 0.042...","[0.006249651, -0.06971085, 0.02562108, 0.02070..."
1,2,2024-05-02 17:40:30,embroidered midi dress,https://cdn-images.farfetch-contents.com/13/78...,This is a red floral lace maxi dress. It is de...,392,women,2,"[0.008111233, -0.028976891, 0.03851886, 0.0523...","[0.0018451426, -0.020285357, 0.022421554, -0.0...","[0.004978188, -0.024631124, 0.030470207, 0.023..."
2,3,2024-05-02 17:40:30,FendiMania sock style sneakers,https://cdn-images.farfetch-contents.com/13/37...,This product is a pair of high-fashion sneaker...,1245,women,3,"[-0.031074192, -0.04274851, 0.05793105, -0.053...","[-0.009132511, -0.0017321664, 0.09899935, -0.0...","[-0.020103352, -0.022240339, 0.0784652, -0.031..."
3,4,2024-05-02 17:40:30,top zip wallet,https://cdn-images.farfetch-contents.com/13/74...,"The pictured product is a wallet, designed in ...",308,women,4,"[-0.0031569542, 0.009241302, -0.029023737, 0.0...","[-0.0011271789, 0.014097556, 0.009860507, 0.03...","[-0.0021420666, 0.011669429, -0.009581614, 0.0..."
4,5,2024-05-02 17:40:30,slingback 65 pumps,https://cdn-images.farfetch-contents.com/13/66...,This product is a women's high-heeled shoe fea...,1248,women,5,"[-0.00018220788, -0.027372295, -0.0020084884, ...","[-0.005171565, -0.03689469, 0.034868743, 0.022...","[-0.0026768863, -0.032133494, 0.016430128, 0.0..."
5,6,2024-05-02 17:40:30,Love Bag shoulder wallet,https://cdn-images.farfetch-contents.com/13/68...,This is a stylish shoulder bag featuring a str...,157,women,6,"[0.02435468, -0.030277371, -0.049517892, 0.017...","[-0.03417675, -0.029232593, -0.020919362, -0.0...","[-0.0049110344, -0.029754981, -0.035218626, 0...."
6,7,2024-05-02 17:40:30,Green Dionysus GG small velvet shoulder bag,https://cdn-images.farfetch-contents.com/12/56...,This is a stylish shoulder bag featuring a lux...,3740,women,7,"[0.06264587, 0.00642968, -0.07207776, 0.015218...","[-0.040029846, -0.01267963, 0.0018165418, 0.02...","[0.011308011, -0.003124975, -0.03513061, 0.017..."
7,8,2024-05-02 17:40:30,logo print satchel,https://cdn-images.farfetch-contents.com/13/76...,This is a designer handbag featuring a stylish...,1323,women,7,"[0.030206611, -0.041917205, -0.04427146, 0.075...","[-0.011056006, -0.015586983, 0.010477726, 0.01...","[0.009575303, -0.028752094, -0.016896868, 0.04..."
8,9,2024-05-02 17:40:30,GG Marmont Matelasse wallet,https://cdn-images.farfetch-contents.com/12/16...,"This item is a compact, black leather wallet. ...",810,women,4,"[-0.019062033, -0.0012936837, -0.00613538, 0.0...","[0.011093294, -0.005519665, -0.009706387, 0.03...","[-0.0039843693, -0.0034066741, -0.007920884, 0..."
9,10,2024-05-02 17:40:30,Blake herringbone midi dress,https://cdn-images.farfetch-contents.com/13/77...,"This is a long-sleeve, beige dress that featur...",989,women,2,"[-0.008873837, -0.022305872, 0.030090122, 0.01...","[-0.019152857, 0.011896507, 0.018871492, 0.022...","[-0.014013347, -0.0052046827, 0.024480807, 0.0..."


In [26]:
df_small.dropna(inplace = True)

In [27]:
dimensions = len(df_small.at[0, "image_vector"])

In [28]:
dimensions

512

In [29]:
%config SqlMagic.named_parameters = True

<div class="alert alert-block alert-warning">
    <b class="fa fa-solid fa-exclamation-circle"></b>
    <div>
        <p><b>Action Required</b></p>
        <p>Select the database from the drop-down menu at the top of this notebook. It updates the <b>connection_url</b> which is used by SQLAlchemy to make connections to the selected database.</p>
    </div>
</div>

In [30]:
from sqlalchemy import *

db_connection = create_engine(connection_url)

In [31]:
%%sql
DROP TABLE IF EXISTS products;

CREATE TABLE IF NOT EXISTS products (
    id BIGINT PRIMARY KEY,
    created_at DATETIME,
    title TEXT,
    image VARCHAR(256),
    description TEXT,
    price DECIMAL(9,2),
    gender VARCHAR(64),
    type_id BIGINT,
    image_vector VECTOR(:dimensions),
    text_vector VECTOR(:dimensions),
    combined_vector VECTOR(:dimensions)
);

In [32]:
df_small.to_sql(
    "products",
    con = db_connection,
    if_exists = "append",
    index = False,
    chunksize = 1000
)

10

In [33]:
random_row = df_small.sample(n = 1)

## Text Query

In [34]:
def get_text_query_vector(text_query, model, device):
    """
    Encodes a text query into a vector using the CLIP model.

    Args:
    - text_query (str): The text query to encode.
    - model: The preloaded CLIP model.
    - device: The device to use ('cpu' or 'cuda').

    Returns:
    - np.ndarray: The text query vector as a NumPy array.
    """
    with torch.no_grad():
        text_query_features = model.encode_text(
            clip.tokenize(text_query).to(device)
        )

    text_query_features /= text_query_features.norm(dim = -1, keepdim = True)

    return text_query_features.cpu().numpy().astype(np.float32)

In [35]:
text_query = random_row["title"].values[0]

text_query_vector = get_text_query_vector(text_query, model, device)

In [36]:
%%sql
SELECT id,
    ROUND(text_vector <*> :text_query_vector, 5) AS similarity
FROM products
ORDER BY similarity DESC
LIMIT 3;

id,similarity
4,1.0
6,0.86425
9,0.79642


## Image Query

In [37]:
def get_image_query_vector(response, model, device, preprocess):
    """
    Encodes an image from a response content into a vector using the CLIP model.

    Args:
    - response: The HTTP response object containing the image content.
    - model: The preloaded CLIP model.
    - device: The device to use ('cpu' or 'cuda').
    - preprocess: The preprocessing function for the CLIP model.

    Returns:
    - np.ndarray: The image query vector as a NumPy array.
    """
    image = preprocess(
        PILImage.open(
            BytesIO(response.content)
        )
    ).unsqueeze(0).to(device)

    with torch.no_grad():
        image_query_features = model.encode_image(image)

    image_query_features /= image_query_features.norm(dim = -1, keepdim = True)

    return image_query_features.cpu().numpy().astype(np.float32)

In [42]:
image_url = random_row["image"].values[0]
response = requests.get(image_url)
display(Image(url = image_url))

image_query_vector = get_image_query_vector(response, model, device, preprocess)

In [43]:
%%sql
SELECT id,
    ROUND(image_vector <*> :image_query_vector, 5) AS similarity
FROM products
ORDER BY similarity DESC
LIMIT 3;

id,similarity
4,1.0
8,0.86505
7,0.8269


## Combined Query

In [44]:
combined_query_vector = (text_query_vector + image_query_vector) / 2

In [45]:
%%sql
SELECT id,
    ROUND(combined_vector <*> :combined_query_vector, 5) AS similarity
FROM products
ORDER BY similarity DESC
LIMIT 3;

id,similarity
4,0.65132
6,0.5684
9,0.55275


## Miscellaneous

In [46]:
text_query = ["pink shoulder bag"]

text_query_vector = get_text_query_vector(text_query, model, device)

In [47]:
%%sql df_query <<
SELECT id, image,
    ROUND(combined_vector <*> :text_query_vector, 5) AS similarity
FROM products
ORDER BY similarity DESC
LIMIT 3;

In [48]:
df_query = pd.DataFrame(df_query)

In [49]:
df_query

Unnamed: 0,id,image,similarity
0,6,https://cdn-images.farfetch-contents.com/13/68...,0.55408
1,8,https://cdn-images.farfetch-contents.com/13/76...,0.51917
2,4,https://cdn-images.farfetch-contents.com/13/74...,0.4966


In [50]:
image_url = df_query.iloc[0]["image"]
response = requests.get(image_url)    
display(Image(url = image_url))