<a href="https://colab.research.google.com/github/Konerusudhir/machine_learning_exercises/blob/master/SemanticSearch0_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **MP NFTs Semantic Search using Text and Images**

I fetch MakersPlace NFTs (10K) and index them using Faiss. Search is performed using a query embedding on the text index or an image embedding on the image index. Below are the individual steps.

1.   Fetch MakersPlace contract NFT token ids (n ... m) metadata using the Alchemy 
     API. Extract preview image URLs from the response.
2.   Fetch the title text for each NFT. Parallel execution is used.
3.   Download the CLIP model (openai/clip-vit-large-patch14) from Hugging Face and instantiate the text model and the vision model.
4.   Create Title Text Index using Faiss
5.   Generate text embeddings for the titles using the CLIP text model and add them to the text index
6.   Load pre generated Image index
7.   Build Gradio App for searching NFTs
8.   Load Web App using Public URL and have fun

## Acknowledgements
1. The text index is built using NFT titles only. Text is limited to alphanumeric characters. 
2. This setup has not been evaluated against a curated/benchmark dataset to verify
   accuracy.
3. Index generation and index search do not run on the GPU. Only image embedding generation runs on the GPU.
4. GPU memory clean-up is done manually. It needs fine-tuning to avoid OOM errors.
5. Search may return unintended images, such as NSFW images.




In [1]:
%%capture
!pip install gradio
!pip install transformers
!pip install faiss-gpu
!pip install torch
!pip install Pillow
!pip install matplotlib
!pip install nltk

In [2]:
import os
import shutil
import math 
import glob
import json
import pickle
import requests
import time
import re
import string

import nltk
from nltk.stem import PorterStemmer
from nltk.stem.wordnet import WordNetLemmatizer
import spacy

from concurrent.futures import ThreadPoolExecutor
import numpy as np
from numpy import random
from PIL import Image
from IPython import display
import matplotlib.pyplot as plt

import faiss

# --- Notebook-wide configuration -------------------------------------------
# Vector size handed to faiss.index_factory when the text index is created.
# Presumably the projection size of the CLIP checkpoint used below — TODO confirm.
DIMENSIONS = 768
# Alchemy batch endpoint used to fetch NFT metadata.
# NOTE(review): the API key is embedded in this URL and is published with the
# notebook; consider reading it from an environment variable instead.
GATEWAY_URL = "https://eth-mainnet.g.alchemy.com/nft/v2/rBshNbJGutTbf2ACdQ9XyGhhc1uSolds/getNFTMetadataBatch"
# Base URL that image ids are appended to when downloading images.
IPFS_GATEWAY = 'https://ipfsgateway.makersplace.com/ipfs/'
# Contract address sent with every token request.
MP_CONTRACT_ADDRESS = "0x2963ba471e265e5f51cafafca78310fe87f8e6d1"
# Only responses with one of these content types are saved to disk.
SUPPORTED_CONTENT_TYPES = ['image/jpeg','image/png','image/gif']
IPFS_IMAGE_IDS_FILE_NAME = "ipfs_image_ids.pickle"
# Local working folders for persisted indexes and downloaded images.
INDEX_FOLDER = "./indexes"
IMAGES_FOLDER = "./images"
IPFS_IMAGE_IDS_PATH = os.path.join(INDEX_FOLDER, IPFS_IMAGE_IDS_FILE_NAME)
# Token id range walked by the metadata fetch loop, in steps of
# GATEWAY_QUERY_BATCH_SIZE.
MIN_TOKEN_ID = 1
MAX_TOKEN_ID = 20000
GATEWAY_QUERY_BATCH_SIZE = 100
# NOTE(review): not referenced by any code visible in this notebook.
VISION_MODEL_INPUT_BATCH_SIZE = 8
# NOTE(review): reassigned to a different value in the text-index cell.
TEXT_MODEL_INPUT_BATCH_SIZE = 10
# Number of nearest neighbours requested from the Faiss indexes per query.
SEARCH_RESULTS_DISPLAY_COUNT = 8
# Progress / duplicate messages are printed once per this many items.
LOG_DISPLAY_THRESHOLD = 1000
# Seed for numpy's global random number generator.
RANDOM_SEED = 7
np.random.seed(RANDOM_SEED)

# Download the NLTK stop-word corpus read by clean_string
nltk.download('stopwords')

# Load the spaCy English pipeline used by clean_string's 'Spacy' option
nlp = spacy.load('en_core_web_sm')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
def clean_directories():
  """Remove the image and index folders; folders that are missing are ignored."""
  for folder in (IMAGES_FOLDER, INDEX_FOLDER):
    shutil.rmtree(folder, ignore_errors=True)

def create_directories():
  """Create the image and index folders when they do not exist yet."""
  for folder in (IMAGES_FOLDER, INDEX_FOLDER):
    if os.path.exists(folder):
      continue
    os.mkdir(folder)

# Uncomment to delete the cached images and indexes before recreating the folders.
# clean_directories()
# Make sure the working folders exist before later cells read or write them.
create_directories()

# Fetch NFTs metadata

In [4]:
%%time
class FetchClass:
    """Collects NFT image ids and titles from the Alchemy batch metadata endpoint.

    Attributes:
        image_duplicates_count: number of NFTs skipped because their image id
            had already been collected.
        text_descriptions: (N, 3) string array with one
            ``[ipfs_id, title, token_id]`` row per unique image.
        ipfs_image_ids: set of image ids collected so far (used to de-duplicate).
    """

    def __init__(self):
        self.image_duplicates_count = 0
        self.text_descriptions = np.empty((0,3), str)
        self.ipfs_image_ids = set()

    def fetch_image_urls(self, min_token_id):
        """Fetch metadata for one batch of token ids and record the new images.

        Requests ``GATEWAY_QUERY_BATCH_SIZE`` consecutive token ids starting at
        ``min_token_id`` and appends the unique results to ``text_descriptions``.

        Args:
            min_token_id: first token id of the batch.

        Raises:
            requests.HTTPError: when the gateway answers with an error status.
            ValueError: when the decoded response is not a list of NFT entries.
        """
        token_id_requests = [
            {
                "contractAddress": MP_CONTRACT_ADDRESS,
                "tokenId": f"{token_id}",
                "tokenType": "ERC721"
            }
            for token_id in range(
                min_token_id, min_token_id + GATEWAY_QUERY_BATCH_SIZE)
        ]

        payload = {
            "tokens": token_id_requests,
            "refreshCache": False
        }
        headers = {
            "accept": "application/json",
            "content-type": "application/json"
        }

        # A timeout keeps a stalled connection from hanging the whole fetch
        # loop; raise_for_status surfaces HTTP errors instead of trying to
        # parse an error body as NFT metadata.
        response = requests.post(
            GATEWAY_URL, json=payload, headers=headers, timeout=60)
        response.raise_for_status()
        self._ingest_metadata(response.json())

    def _ingest_metadata(self, nfts_metadata):
        """Append one row per previously unseen image in a decoded response."""
        if not isinstance(nfts_metadata, list):
            raise ValueError(f"Unexpected gateway response: {nfts_metadata!r}")

        new_rows = []
        for nft_metadata in nfts_metadata:
            if not isinstance(nft_metadata, dict):
                continue
            # Entries may lack 'metadata' or 'id' altogether, so neither key
            # is indexed directly.
            image_url = (nft_metadata.get('metadata') or {}).get('imageUrl')
            description = nft_metadata.get('title')
            token_id = (nft_metadata.get('id') or {}).get('tokenId')
            # A None token id would turn the array into object dtype, which
            # np.load cannot read back with its default settings.
            if image_url is None or description is None or token_id is None:
                continue

            ipfs_id = image_url.split('/')[-1]
            if ipfs_id not in self.ipfs_image_ids:
                self.ipfs_image_ids.add(ipfs_id)
                new_rows.append([ipfs_id, description, token_id])
            else:
                self.image_duplicates_count += 1
                if self.image_duplicates_count%LOG_DISPLAY_THRESHOLD == 0:
                    print(f"Found {self.image_duplicates_count} duplicates")

        # Append once per response rather than once per row: np.append copies
        # the whole array on every call.
        if new_rows:
            self.text_descriptions = np.append(
                self.text_descriptions, np.array(new_rows, dtype=str), axis=0)
        
# Cache file for the fetched metadata array. np.save/np.load use the ".npy"
# suffix, so the resulting file name ends in ".pickle.npy".
np_file_prefix = "mp_nft_data_np_array_12k.pickle"
np_file_name = f"{np_file_prefix}.npy"
np_file_path = os.path.join(INDEX_FOLDER, np_file_name)

# Reuse the cached array when present; otherwise fetch every batch from the
# gateway and cache the result for later runs.
if os.path.exists(np_file_path):
    text_descriptions = np.load(np_file_path)
else:
    fetch_class = FetchClass()
    for batch_start_index in range(
        MIN_TOKEN_ID, MAX_TOKEN_ID, GATEWAY_QUERY_BATCH_SIZE):
        
        fetch_class.fetch_image_urls(batch_start_index)
        
        # Batch starts are offset by MIN_TOKEN_ID (1), hence the "- 1" so the
        # message is printed once per LOG_DISPLAY_THRESHOLD token ids.
        if (batch_start_index - 1)%LOG_DISPLAY_THRESHOLD == 0:
            print(f"Fetched batch from {batch_start_index}")
        
        # Short pause between requests — presumably to stay under the
        # gateway's rate limit; TODO confirm.
        time.sleep(0.1)    

    text_descriptions = fetch_class.text_descriptions
    np.save(np_file_path, text_descriptions)

print(f"""
    Text Array shape: {text_descriptions.shape}
    """)


    Text Array shape: (12306, 3)
    
CPU times: user 8.17 ms, sys: 70.9 ms, total: 79.1 ms
Wall time: 79.6 ms


# Clean NFT title Text

In [5]:
def clean_string(text, stem="None", max_words=60):
    """Normalise a title into a lower-case, stop-word-free string.

    Args:
        text: raw title text.
        stem: word normalisation to apply: 'Stem' (Porter stemmer), 'Lem'
            (WordNet lemmatizer), 'Spacy' (spaCy lemmas); any other value
            leaves the words unchanged.
        max_words: maximum number of words kept in the result; ``None`` keeps
            all of them.

    Returns:
        The remaining words joined by single spaces (possibly an empty string).
    """
    # Keep only ASCII letters, digits and spaces. This already strips line
    # breaks and punctuation, so no separate passes are needed for those.
    text = re.sub(r"[^a-zA-Z0-9 ]", "", text).lower()

    # A set gives constant-time membership tests instead of scanning the
    # stop-word list once per word.
    useless_words = set(nltk.corpus.stopwords.words("english"))
    useless_words.update(('hi', 'im'))
    text_filtered = [word for word in text.split() if word not in useless_words]

    # Stem or lemmatize
    if stem == 'Stem':
        stemmer = PorterStemmer()
        text_stemmed = [stemmer.stem(word) for word in text_filtered]
    elif stem == 'Lem':
        lem = WordNetLemmatizer()
        text_stemmed = [lem.lemmatize(word) for word in text_filtered]
    elif stem == 'Spacy':
        doc = nlp(' '.join(text_filtered))
        text_stemmed = [token.lemma_ for token in doc]
    else:
        text_stemmed = text_filtered

    return ' '.join(text_stemmed[:max_words])



# 3 - Download Clip Model

In [6]:
%%time
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPProcessor, CLIPVisionModelWithProjection, TFCLIPTextModel, TFCLIPVisionModel


# Checkpoint id shared by the text and vision models loaded below.
clip_model_id = "openai/clip-vit-large-patch14"

# Text side: model with projection head plus its tokenizer (used by get_text_embeds).
text_model = CLIPTextModelWithProjection.from_pretrained(clip_model_id)
tokenizer = CLIPTokenizer.from_pretrained(clip_model_id)

# Image side: model with projection head plus its pre-processor (used by get_image_embeds).
vision_model = CLIPVisionModelWithProjection.from_pretrained(clip_model_id)
processor = CLIPProcessor.from_pretrained(clip_model_id)

# NOTE(review): device placement is disabled, so the models stay wherever
# from_pretrained puts them. Re-enabling this also needs `import torch`, which
# no visible cell provides.
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# vision_model.to(device)

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModelWithProjection: ['vision_model.encoder.layers.6.mlp.fc2.weight', 'vision_model.encoder.layers.15.self_attn.q_proj.weight', 'vision_model.encoder.layers.7.mlp.fc2.bias', 'vision_model.encoder.layers.1.layer_norm2.weight', 'vision_model.encoder.layers.3.mlp.fc1.bias', 'vision_model.encoder.layers.20.self_attn.v_proj.weight', 'vision_model.encoder.layers.21.mlp.fc2.bias', 'vision_model.encoder.layers.3.self_attn.q_proj.bias', 'vision_model.encoder.layers.10.layer_norm2.weight', 'vision_model.encoder.layers.7.self_attn.q_proj.bias', 'vision_model.encoder.layers.9.self_attn.k_proj.weight', 'vision_model.encoder.layers.2.self_attn.v_proj.weight', 'vision_model.encoder.layers.2.self_attn.out_proj.bias', 'vision_model.encoder.layers.0.self_attn.k_proj.bias', 'vision_model.encoder.layers.11.layer_norm1.bias', 'vision_model.encoder.layers.19.mlp.fc2.bias', 'vision_model.encoder.laye

CPU times: user 7.41 s, sys: 6.32 s, total: 13.7 s
Wall time: 18.4 s


# Build Title Index

In [7]:
%%time
# NOTE(review): overrides the value assigned in the configuration cell.
TEXT_MODEL_INPUT_BATCH_SIZE = 200
# Faiss factory string: a "Flat" index wrapped in an id map so vectors can be
# added with caller-chosen ids (see add_with_ids in build_text_index).
storage = "Flat"
index_name = f"IDMap,{storage}"
text_index = faiss.index_factory(DIMENSIONS, index_name)
# Maps the integer id stored in the index back to the IPFS image id.
descriptions_ids_map = {}
# Token ids that are skipped when the text index is built.
banned_token_ids = {
    1185 # hash special characters which crash the tokenizer
}

def get_text_embeds(queries):
    """Embed one or more strings with the CLIP text model.

    Args:
        queries: a string or a list of strings.

    Returns:
        numpy array holding ``text_embeds`` from the model output, one row per
        query.
    """
    # Local import: no cell of this notebook imports torch at module level.
    import torch

    # truncation=True clips over-long inputs to the tokenizer's maximum length
    # instead of passing the model more positions than it supports.
    inputs = tokenizer(queries, padding=True, truncation=True, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        outputs = text_model(**inputs)
    return outputs.text_embeds.detach().numpy()

def build_text_index():
    """Embed every cleaned title and add it to the global ``text_index``.

    Rows of ``text_descriptions`` are processed in batches of at most
    ``TEXT_MODEL_INPUT_BATCH_SIZE``. Rows whose token id is in
    ``banned_token_ids`` are skipped. For each indexed row the id used in the
    index is recorded in ``descriptions_ids_map``. A batch that fails is
    reported and skipped so the remaining batches are still indexed.
    """
    total_rows = len(text_descriptions)
    if total_rows == 0:
        print("No descriptions to index")
        return

    # np.array_split expects a number of sections, not a section size, so the
    # batch size is converted into a batch count first.
    batch_count = math.ceil(total_rows / TEXT_MODEL_INPUT_BATCH_SIZE)
    embeds_count = 0

    for text_descriptions_batch in np.array_split(text_descriptions, batch_count):
        cleaned_text = []
        ipfs_ids_hashes = []
        token_ids = []
        batch_ids_map = {}

        for row in text_descriptions_batch:
            token_id = int(row[2])
            if token_id in banned_token_ids:
                continue
            ipfs_id = row[0]
            # NOTE(review): Python salts str hashes per process unless
            # PYTHONHASHSEED is fixed, so these ids are only meaningful
            # together with the id map persisted alongside the index.
            ipfs_id_hash = hash(ipfs_id)
            ipfs_ids_hashes.append(ipfs_id_hash)
            token_ids.append(token_id)
            batch_ids_map[ipfs_id_hash] = ipfs_id
            cleaned_text.append(clean_string(str(row[1])))

        if not cleaned_text:
            continue

        try:
            text_description_embeds = get_text_embeds(cleaned_text)
            text_index.add_with_ids(
                text_description_embeds,
                np.array(ipfs_ids_hashes, dtype=np.int64))
        except Exception as e:
            # Best effort: report the whole failing batch with its cause and
            # carry on with the next one.
            print(f"""
                Failed to index batch: {e!r}
                Text : {text_descriptions_batch}  - cleaned: {cleaned_text}
                IPFS Hashes: {ipfs_ids_hashes}
                TokenIds:    {token_ids}
                """)
            continue

        # Record ids only for vectors that actually made it into the index.
        descriptions_ids_map.update(batch_ids_map)
        embeds_count += len(cleaned_text)
        print(f"Created embeds for {embeds_count} descriptions")


# File names are derived from the factory string, e.g. "IDMap_Flat_text.index".
text_index_file_prefix = index_name.replace(',', '_')
text_index_file_name = f"{text_index_file_prefix}_text.index"
text_index_file_path = os.path.join(INDEX_FOLDER, text_index_file_name)
# Reuse a previously written index when available; otherwise build and persist it.
if os.path.exists(text_index_file_path):
    text_index = faiss.read_index(text_index_file_path)
else:    
    build_text_index()
    faiss.write_index(text_index, text_index_file_path)


# The id map is persisted next to the index so index ids can be translated
# back to IPFS ids in later sessions.
# NOTE(review): if the index file was loaded from disk but this pickle is
# missing, the in-memory map is still empty and an empty map gets written.
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# files this notebook produced.
text_id_map_file_name = f"{text_index_file_prefix}_text_ids.pickle"
text_id_map_path = os.path.join(INDEX_FOLDER, text_id_map_file_name)
if os.path.exists(text_id_map_path):
    with open(text_id_map_path, 'rb') as f:
        descriptions_ids_map = pickle.load(f)
else:
    with open(text_id_map_path, 'wb') as f:
        pickle.dump(descriptions_ids_map, f)


# Both sizes should match: one map entry per indexed vector.
print(f"Text Index Size: {text_index.ntotal}")
print(f"Text Ids Size: {len(descriptions_ids_map)}")

Text Index Size: 12305
Text Ids Size: 12305
CPU times: user 55.9 ms, sys: 22.3 ms, total: 78.3 ms
Wall time: 132 ms


# Utils

In [8]:
%%time

def load_resized_image(img_path, max_width = 300, max_height = 300, timeout = 30):
    """Load an image from a local path or an HTTP(S) URL as an RGB array.

    Images larger than ``max_width`` x ``max_height`` are shrunk in place with
    ``Image.thumbnail``, which keeps the aspect ratio.

    Args:
        img_path: local file path or ``http://`` / ``https://`` URL.
        max_width: maximum width of the returned image.
        max_height: maximum height of the returned image.
        timeout: seconds to wait for a remote image.

    Returns:
        The image as a numpy array, or ``None`` when it cannot be loaded (the
        failure is printed).
    """
    try:
        # Match on the scheme: a substring test would also treat local paths
        # that merely contain "http" as URLs.
        if img_path.startswith(('http://', 'https://')):
            response = requests.get(img_path, stream=True, timeout=timeout)
            response.raise_for_status()
            # Let urllib3 undo any content-encoding before PIL reads the bytes.
            response.raw.decode_content = True
            source = response.raw
        else:
            source = img_path
        # convert() loads the pixel data, so the underlying file can be closed
        # as soon as the block exits. Both branches yield RGB.
        with Image.open(source) as opened:
            img = opened.convert('RGB')
        width, height = img.size
        if width > max_width or height > max_height:
            img.thumbnail((max_width, max_height))
        return np.asarray(img)
    except Exception as e:
        # str() keeps the handler itself from failing on a non-string path.
        print(f"IPFS ID:{str(img_path).split('/')[-1]} - {e}")
        return None

def download_images(ipfs_ids):
    """Download the given IPFS images into ``IMAGES_FOLDER`` in parallel.

    Images that already exist locally are skipped. A response is saved only
    when it has status 200 and a content type listed in
    ``SUPPORTED_CONTENT_TYPES``; anything else is printed and skipped.

    Args:
        ipfs_ids: iterable of IPFS image ids (used as the local file names).
    """

    def fetch_image(ipfs_image_id):
        """Download one image unless it is already cached."""
        image_local_path = os.path.join(IMAGES_FOLDER, ipfs_image_id)
        if os.path.exists(image_local_path):
            return

        # URLs are joined with "/" explicitly; os.path.join is meant for
        # filesystem paths and uses the platform separator.
        image_url = f"{IPFS_GATEWAY.rstrip('/')}/{ipfs_image_id}"
        try:
            response = requests.get(image_url, timeout=60)
            content_type = response.headers.get('content-type')
            # status_code is always truthy, so it must be compared explicitly.
            if response.status_code == 200 and content_type in SUPPORTED_CONTENT_TYPES:
                with open(image_local_path, 'wb') as fp:
                    fp.write(response.content)
            else:
                print(f"HTTP Code:{response.status_code} - {content_type} - IPFS ID:{ipfs_image_id}")
        except (requests.RequestException, OSError) as e:
            # Errors raised in worker threads would otherwise vanish because
            # the results of executor.map are never consumed.
            print(f"Download failed - {e} - IPFS ID:{ipfs_image_id}")

    with ThreadPoolExecutor(max_workers=32) as executor:
        executor.map(fetch_image, ipfs_ids)

def search_descriptions_using_text(local_text_index, text_strings):
    """Return the ids of the nearest indexed titles for each query string.

    The result is the id array returned by the index search, with
    ``SEARCH_RESULTS_DISPLAY_COUNT`` ids requested per query.
    """
    query_embeds = get_text_embeds(text_strings)
    _distances, description_ids = local_text_index.search(
        query_embeds, SEARCH_RESULTS_DISPLAY_COUNT)
    return description_ids


def read_ipfs_image_ids(image_ids_path):
    """Return the object pickled at ``image_ids_path``, or an empty set if the file is absent."""
    if not os.path.exists(image_ids_path):
        return set()
    # NOTE(review): pickle.load can execute arbitrary code; only read files
    # produced by this notebook.
    with open(image_ids_path, 'rb') as ids_file:
        return pickle.load(ids_file)

def get_image_embeds(images):
  """Embed one or more images with the CLIP vision model.

  Args:
    images: an image or a list of images accepted by the CLIP processor.

  Returns:
    The ``image_embeds`` tensor from the model output.
  """
  # Local import: no cell of this notebook imports torch at module level.
  import torch

  inputs = processor(images=images, return_tensors="pt")
  # Inference only: skip autograd bookkeeping to save memory.
  with torch.no_grad():
    outputs = vision_model(**inputs)
  return outputs.image_embeds

def search_images_using_images(image_index, search_images):
    """Return the ids of the indexed images nearest to the first query image.

    Any failure while embedding or searching is printed and an empty list is
    returned instead.
    """
    try:
        query_embeds = get_image_embeds(search_images).detach().numpy()
        _distances, image_ids = image_index.search(
            query_embeds, SEARCH_RESULTS_DISPLAY_COUNT)
        # Only the neighbours of the first query row are reported.
        return list(image_ids[0])
    except Exception as e:
        print(f"Bad Image IPFS ID:{search_images} - {e}") 
    return []

CPU times: user 10 µs, sys: 0 ns, total: 10 µs
Wall time: 14.3 µs


# Gradio Image search App

In [11]:
import gradio as gr

# Load the persisted set of IPFS image ids (an empty list when the file is missing).
# The download call below is disabled; enable it to pre-fetch a few images.
loaded_ipfs_image_ids = list(read_ipfs_image_ids(IPFS_IMAGE_IDS_PATH))
# download_images(loaded_ipfs_image_ids[:30])

# Load the pre-generated image index; its file name is derived from the Faiss
# factory string (e.g. "IDMap_Flat.index").
# NOTE(review): there is no existence check here, so the index and id-map
# files must already be present in INDEX_FOLDER.
image_index_prefix = index_name.replace(',', '_')
image_index_file_name = f"{image_index_prefix}.index"
index_path = os.path.join(INDEX_FOLDER, image_index_file_name)
loaded_image_index = faiss.read_index(index_path)
# Load the map from image-index ids back to IPFS image ids.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
id_map_file_name = f"{image_index_prefix}_image_ids.pickle"
id_map_path = os.path.join(INDEX_FOLDER, id_map_file_name)
with open(id_map_path, 'rb') as f:
    loaded_image_ids_map = pickle.load(f)

# Sanity check: each index should hold as many vectors as its id map has entries.
print(f"Text Index size: {text_index.ntotal}")
print(f"Image Index size: {loaded_image_index.ntotal}")
print(f"Text Ids size: {len(descriptions_ids_map)}")
print(f"Image Ids size: {len(loaded_image_ids_map)}")

def search_images_by_image(query_image):
    """Image-only search: delegates to ``search_images`` with an empty text query."""
    return search_images(query_text='', query_image=query_image)

def search_images(query_text, query_image):
    """Search NFT images by text or, when no text is given, by image.

    A non-blank ``query_text`` is searched against the text index. Otherwise a
    provided ``query_image`` is searched against the image index. The matching
    images are downloaded if needed and loaded from ``IMAGES_FOLDER``.

    Args:
        query_text: text query; ``None``, empty or blank means "no text".
        query_image: query image, or ``None``.

    Returns:
        A list with one entry per match, each the array returned by
        ``load_resized_image`` (``None`` for an image that failed to load).
        The list is empty when neither input is provided.
    """
    if query_text and query_text.strip():
        ids_map = descriptions_ids_map
        search_results = search_descriptions_using_text(text_index, [query_text]).flatten()
    elif query_image is not None:
        ids_map = loaded_image_ids_map
        search_results = search_images_using_images(loaded_image_index, [query_image])
    else:
        # Nothing to search for: return early instead of falling through with
        # search_results / ids_map unbound.
        return []

    # Skip ids missing from the map (for example a placeholder id returned
    # when the index has fewer results than requested) instead of raising
    # KeyError.
    matched_ipfs_ids = []
    for image_hash in search_results:
        ipfs_id = ids_map.get(image_hash)
        if ipfs_id is not None:
            matched_ipfs_ids.append(ipfs_id)

    download_images(matched_ipfs_ids)
    return [
        load_resized_image(os.path.join(IMAGES_FOLDER, ipfs_id))
        for ipfs_id in matched_ipfs_ids
    ]

# Gradio UI: a query column (text box, image upload, search button) followed
# by two rows of three result tiles, each with a "Related Images" button.
with gr.Blocks() as demo:
    with gr.Column(scale=1):
        query_text = gr.Textbox(label="query_text", value="universe")
        query_image = gr.Image(label="query_image")
        search_btn = gr.Button("Search")        
    

    # First row of result tiles.
    with gr.Row():
        with gr.Column(scale=1):
            result_0_image = gr.Image(type="pil", label="result_0_image")
            # NOTE(review): this HTML component is created but never used as
            # an input or output of any event below.
            result_0_html =  gr.HTML(label="result_0_html", show_label = True)
            result_0_btn = gr.Button("Related Images")    
        with gr.Column(scale=1):
            result_1_image = gr.Image(type="pil", label="result_1_image")
            result_1_btn = gr.Button("Related Images")  
        with gr.Column(scale=1):
            result_2_image = gr.Image(type="pil", label="result_2_image")
            result_2_btn = gr.Button("Related Images")  
    # Second row of result tiles.
    with gr.Row():    
        with gr.Column(scale=1):
            result_3_image = gr.Image(type="pil", label="result_3_image")
            result_3_btn = gr.Button("Related Images")  
        with gr.Column(scale=1):
            result_4_image = gr.Image(type="pil", label="result_4_image")
            result_4_btn = gr.Button("Related Images")  
        with gr.Column(scale=1):
            result_5_image = gr.Image(type="pil", label="result_5_image")
            result_5_btn = gr.Button("Related Images")  

    # NOTE(review): six output images are wired up, while the searches request
    # SEARCH_RESULTS_DISPLAY_COUNT (8) neighbours — confirm how Gradio treats
    # a handler that returns a different number of values than outputs.
    inputs = [query_text, query_image]    
    outputs = [
        result_0_image, 
        result_1_image, 
        result_2_image, 
        result_3_image, 
        result_4_image,
        result_5_image]

    # Main search: text takes precedence over the uploaded image.
    search_btn.click(fn=search_images, inputs=inputs, outputs=outputs)
    
    # Each "Related Images" button re-runs the search with that tile's image
    # as the query and refreshes all result tiles.
    result_0_btn.click(fn=search_images_by_image, 
                       inputs=result_0_image, outputs=outputs)
    result_1_btn.click(fn=search_images_by_image, 
                       inputs=result_1_image, outputs=outputs)
    result_2_btn.click(fn=search_images_by_image, 
                       inputs=result_2_image, outputs=outputs)
    result_3_btn.click(fn=search_images_by_image, 
                       inputs=result_3_image, outputs=outputs)
    result_4_btn.click(fn=search_images_by_image, 
                       inputs=result_4_image, outputs=outputs)
    result_5_btn.click(fn=search_images_by_image, 
                       inputs=result_5_image, outputs=outputs)
    

# share=True requests a public URL; debug=True keeps the cell running so
# errors and logs stay visible (see the captured output below).
demo.launch(share=True, debug=True,)

Text Index size: 12305
Image Index size: 10957
Text Ids size: 12305
Image Ids size: 10957
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://7428be80-d0ec-47f4.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://7428be80-d0ec-47f4.gradio.live




In [12]:
# Shut down the Gradio server started by demo.launch().
demo.close()

Closing server running on port: 7860


In [13]:
# Disabled: manual re-save of the fetched metadata array.
# file_name = "mp_nft_data_np_array_12k.pickle"
# file_path = os.path.join(INDEX_FOLDER, file_name)
# np.save(file_path, fetch_class.text_descriptions)

# Pack the whole index folder into "<INDEX_FOLDER>.tar"; the archive path is
# the cell's output.
shutil.make_archive(f"{INDEX_FOLDER}", 'tar', INDEX_FOLDER)




'/content/indexes.tar'

In [None]:
# gpu_info = !nvidia-smi
# gpu_info = '\n'.join(gpu_info)
# if gpu_info.find('failed') >= 0:
#   print('Not connected to a GPU')
# else:
#   print(gpu_info)