# `clip_retrieval.clip_client`

This Python module allows you to query a remote backend via its exposed REST API.

## Install

In [None]:
# Install the client library used throughout this notebook, plus img2dataset,
# which the last section invokes to download the query results.
%pip install clip-retrieval img2dataset

## Setup

In [None]:
from IPython.display import Image, display

from clip_retrieval.clip_client import ClipClient, Modality

# Common URL prefix of the sample images in the clip-retrieval repository;
# image file names are appended to it later in the notebook.
IMAGE_BASE_URL = "https://github.com/rom1504/clip-retrieval/raw/main/tests/test_clip_inference/test_images/"


def log_result(result):
    """Print the metadata of one search hit and render its image inline.

    Args:
        result: Object subscriptable by the keys ``"id"``, ``"caption"``,
            ``"url"`` and ``"similarity"``. A missing key raises before
            anything is printed, because all four are read up front.
    """
    hit_id = result["id"]
    hit_caption = result["caption"]
    hit_url = result["url"]
    hit_similarity = result["similarity"]

    print(f"id: {hit_id}")
    print(f"caption: {hit_caption}")
    print(f"url: {hit_url}")
    print(f"similarity: {hit_similarity}")

    # The image is referenced by URL rather than downloaded by this function.
    display(Image(url=hit_url, unconfined=True))


# Settings for the shared client: which service and index to query, the
# aesthetic ranking knobs, the modality to search and how many hits to request.
client_settings = {
    "url": "https://knn.laion.ai/knn-service",
    "indice_name": "laion5B-L-14",
    "aesthetic_score": 9,
    "aesthetic_weight": 0.5,
    "modality": Modality.IMAGE,
    "num_images": 10,
}
client = ClipClient(**client_settings)

## Query by text

In [None]:
# Query the index with a free-text prompt and show the first hit returned.
cat_results = client.query(text="an image of a cat")
first_cat_hit = cat_results[0]
log_result(first_cat_hit)

## Query by image

In [None]:
# Query with an image URL instead of text. The URL is built from
# IMAGE_BASE_URL (defined in the setup cell) so the repository path is not
# repeated here; the resulting string is the same as before.
beach_results = client.query(image=f"{IMAGE_BASE_URL}321_421.jpg")
log_result(beach_results[0])

## Query by embedding

In [None]:
import clip  # pylint: disable=import-outside-toplevel
import torch

# Load the "ViT-L/14" CLIP model on the CPU together with its preprocessing
# callable; both are used as globals by the embedding helpers below.
# NOTE(review): presumably this must match the model behind the
# "laion5B-L-14" index queried by the client — confirm.
model, preprocess = clip.load("ViT-L/14", device="cpu", jit=True)

import io
import urllib

import numpy as np


def download_image(url):
    """Fetch ``url`` and return its raw bytes wrapped in an in-memory stream.

    Args:
        url: Address of the resource to download.

    Returns:
        io.BytesIO: Stream holding the complete response body, positioned at
        the start.

    Raises:
        urllib.error.URLError: If the URL cannot be opened.
    """
    # A bare ``import urllib`` does not load the ``request`` submodule, so
    # ``urllib.request`` only resolves if some other package imported it
    # first. Import it explicitly to make this function self-sufficient.
    import urllib.request

    request = urllib.request.Request(
        url,
        data=None,
        # Send a browser-like User-Agent instead of urllib's default one.
        headers={"User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"},
    )
    # The context manager closes the connection as soon as the body is read.
    with urllib.request.urlopen(request, timeout=10) as response:
        return io.BytesIO(response.read())


def normalized(a, axis=-1, order=2):
    """Scale ``a`` so that its norm along ``axis`` is 1.

    Args:
        a: Array to normalise.
        axis: Axis along which the norm is computed (default: last axis).
        order: Norm order passed to ``np.linalg.norm`` (default: 2).

    Returns:
        ``a`` divided by its norms along ``axis``. Slices whose norm is zero
        are returned unchanged instead of producing NaNs. For 1-D input with
        the default axis the result has shape ``(1, n)``, because the scalar
        norm is promoted to a 1-element array and then expanded before
        dividing.
    """
    norms = np.atleast_1d(np.linalg.norm(a, ord=order, axis=axis))
    # Substitute 1 for zero norms so those slices are divided by 1.
    np.putmask(norms, norms == 0, 1)
    divisor = np.expand_dims(norms, axis)
    return a / divisor


def get_text_emb(text):
    """Encode ``text`` with the global CLIP ``model`` into a unit-norm vector.

    Args:
        text: The string to embed; it is tokenised with ``truncate=True``.

    Returns:
        A 1-D ``float32`` NumPy array: the first (only) row of the encoded
        batch, divided by its norm along the last dimension.
    """
    with torch.no_grad():
        tokens = clip.tokenize([text], truncate=True).to("cpu")
        features = model.encode_text(tokens)
        # Normalise in place to unit length along the feature dimension.
        features /= features.norm(dim=-1, keepdim=True)
        as_numpy = features.cpu().detach().numpy().astype("float32")
        unit_vector = as_numpy[0]
    return unit_vector


from PIL import Image as pimage


def get_image_emb(image_url):
    """Download an image and encode it with the global CLIP ``model``.

    Args:
        image_url: URL of the image, fetched via ``download_image``.

    Returns:
        A 1-D ``float32`` NumPy array: the first (only) row of the encoded
        batch, divided by its norm along the last dimension.
    """
    with torch.no_grad():
        picture = pimage.open(download_image(image_url))
        # preprocess() output gets a leading batch dimension of size 1.
        batch = preprocess(picture).unsqueeze(0).to("cpu")
        features = model.encode_image(batch)
        # Normalise in place to unit length along the feature dimension.
        features /= features.norm(dim=-1, keepdim=True)
        as_numpy = features.cpu().detach().numpy().astype("float32")
    return as_numpy[0]

In [None]:
# Embed the prompt locally, then query the index with the raw vector
# (sent as a plain Python list).
red_tshirt_text_emb = get_text_emb("red tshirt")
red_tshirt_query = red_tshirt_text_emb.tolist()
red_tshirt_results = client.query(embedding_input=red_tshirt_query)
log_result(red_tshirt_results[0])

In [None]:
# Same as the text case, but the query vector comes from an image embedding.
blue_dress_image_emb = get_image_emb(
    "https://rukminim1.flixcart.com/image/612/612/kv8fbm80/dress/b/5/n/xs-b165-royal-blue-babiva-fashion-original-imag86psku5pbx2g.jpeg?q=70"
)
blue_dress_query = blue_dress_image_emb.tolist()
blue_dress_results = client.query(embedding_input=blue_dress_query)
log_result(blue_dress_results[0])

In [None]:
# Combine a text embedding and an image embedding into one query vector.
# Both are recomputed here so the cell does not depend on the two cells above.
red_tshirt_text_emb = get_text_emb("red tshirt")
blue_dress_image_emb = get_image_emb(
    "https://rukminim1.flixcart.com/image/612/612/kv8fbm80/dress/b/5/n/xs-b165-royal-blue-babiva-fashion-original-imag86psku5pbx2g.jpeg?q=70"
)
summed_emb = red_tshirt_text_emb + blue_dress_image_emb
# normalized() hands back a 2-D array for 1-D input, hence the [0].
mean_emb = normalized(summed_emb)[0]
mean_query = mean_emb.tolist()
mean_results = client.query(embedding_input=mean_query)
log_result(mean_results[0])

## Download and format a dataset from the results of a query

If you have some images of your own, you can query each one and collect the results into a custom dataset (a small subset of LAION-5B).

In [None]:
# Create urls from known images in repo
import json

from tqdm import tqdm

# File names of the sample images; each is appended to IMAGE_BASE_URL to form
# a full URL.
test_image_names = (
    "123_456.jpg",
    "208_495.jpg",
    "321_421.jpg",
    "389_535.jpg",
    "416_264.jpg",
    "456_123.jpg",
    "524_316.jpg",
)
test_images = [IMAGE_BASE_URL + file_name for file_name in test_image_names]

# Rebind `client` to a new instance that requests more hits per query. Note
# that this instance is created without the aesthetic and modality arguments
# used in the setup cell.
client = ClipClient(
    url="https://knn.laion.ai/knn-service",
    indice_name="laion5B-L-14",
    num_images=40,
)

# Query once per sample image and pool every hit into a single flat list.
combined_results = []
for test_image_url in tqdm(test_images):
    combined_results += client.query(image=test_image_url)

# Persist the pooled hits as JSON.
with open("search-results.json", "w") as results_file:
    json.dump(combined_results, results_file)

In [None]:
# Run the img2dataset CLI on the JSON file written by the previous cell:
# captions come from the "caption" column, images are not resized, and the
# output goes to laion-enhanced-dataset/ in the "files" format.
!img2dataset "search-results.json" --input_format="json" --caption_col "caption" --output_folder="laion-enhanced-dataset" --resize_mode="no" --output_format="files"

In [None]:
import os

# Report where the dataset folder is: first built from Python's current
# working directory, then as resolved by the shell's `realpath`.
print(f"Done! Download/copy the contents of {os.getcwd()}/laion-enhanced-dataset/")
!realpath laion-enhanced-dataset/