# Using CLIP Vectors to Find Similar Images in a Set of Images

### Computing Platform Check: GPU (CUDA) or CPU

In [None]:
import torch

# Select the compute device once: CUDA when torch reports it as available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    # Let the user know that the heavy cells will be slow without a GPU.
    print ('[WARNING] CUDA/GPU is not available! Compute-intensive scripts on this notebook will be run on CPU.')

### Requirements

In [None]:
import os
import json
import random
import numpy as np
from numpy import linalg
from tqdm import tqdm
import hashlib
import torch
import open_clip
from PIL import Image
from matplotlib import pyplot as plt

### Specify Data Source and Path for Writing JSON Files

In [None]:
# Note: The data folder must be a regular (unzipped) folder containing images in PNG or JPG format.
# It is walked recursively by create_data_source_json().
input_dir = './downloads/Tile_Generator_Genetic_Algo_V1_16x16-2023-23-2--16-01-20/'
# Output path of the JSON file holding the image list (file path and file name per image hash)
img_list_json_path = './image_list.json'
# Output path of the JSON file holding the CLIP vector per image hash
clip_vectors_json_path = './image_clip_vectors.json'

### Define Functions and Create Hash and CLIP Model Object

In [None]:
def get_hash(file_path, hasher):
    '''
    Return the hex digest of the content of a single file.

    Parameters:
        file_path: path of the file to hash.
        hasher: a hashlib-style hash object (anything providing copy(),
            update() and hexdigest()). It is used as a template only and is
            NOT modified, so one hasher can safely be shared across many calls.

    Returns:
        The hexadecimal digest (str) of the hasher's current state extended
        with the bytes of the file.
    '''
    # Work on a copy: updating the shared hasher in place would make every
    # digest depend on all files hashed before it, not on this file alone.
    file_hasher = hasher.copy()
    with open(file_path, 'rb') as img_file:
        # Read in 1 MiB chunks so large files are not loaded into memory at once
        for chunk in iter(lambda: img_file.read(1 << 20), b''):
            file_hasher.update(chunk)
    return file_hasher.hexdigest()

def get_clip(clip_model_type = 'ViT-L-14' , pretrained = 'openai'):
    '''
    Create the CLIP model and its image preprocessing transform.

    Parameters:
        clip_model_type: model architecture name passed to open_clip.
        pretrained: pretrained weights tag passed to open_clip.

    Returns:
        (clip_model, preprocess, device) where device is "cuda" when
        torch reports CUDA as available and "cpu" otherwise. The model is
        moved to that device and switched to evaluation mode.
    '''
    clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_model_type, pretrained=pretrained)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The model must live on the same device as the image tensors that
    # compute_clip() sends to `device`, otherwise encode_image fails on CUDA.
    clip_model = clip_model.to(device)
    # Inference only: disable training-time behaviour such as dropout.
    clip_model.eval()
    return clip_model , preprocess , device

def compute_clip(img, clip_model, preprocess, device):
    '''
    Compute the CLIP image vector of a single image.

    Parameters:
        img: the image, in whatever form `preprocess` accepts.
        clip_model: object exposing encode_image(tensor).
        preprocess: callable turning `img` into a tensor (without batch axis).
        device: torch device (e.g. "cuda" or "cpu") the input is moved to.

    Returns:
        The output of encode_image as a NumPy array; a leading batch axis of
        size 1 is added to the input before encoding.
    '''
    img_tensor = preprocess(img).unsqueeze(0).to(device)
    # No gradients are needed for inference; skipping them saves memory and time
    with torch.no_grad():
        features = clip_model.encode_image(img_tensor)
    # .cpu() is required before .numpy() when the tensor lives on the GPU
    return features.detach().cpu().numpy()

# SHA-256 hash object; create_data_source_json() passes it to get_hash() for every image file
hasher = hashlib.sha256()
# Load the CLIP model with get_clip()'s default arguments.
# Note: this rebinds the module-level `device` with the value returned by get_clip().
clip_model, preprocess, device = get_clip()

### Create Image List JSON and Image CLIP Vectors JSON Source from Input Images

In [None]:
def create_data_source_json(input_dir, img_list_json_path, clip_vectors_json_path):
    '''
    Walk `input_dir`, hash every .png/.jpg file, compute its CLIP vector and
    write both results to JSON files.

    Uses the module-level `hasher`, `clip_model`, `preprocess` and `device`.

    Parameters:
        input_dir: folder that is walked recursively for image files.
        img_list_json_path: output path for the image list JSON.
        clip_vectors_json_path: output path for the CLIP vectors JSON.

    Returns:
        (img_list_json, clip_vectors_json), both keyed by file hash:
            img_list_json[hash]     = {'file_path': ..., 'file_name': ...}
            clip_vectors_json[hash] = {'clip_vector': <nested list>}
        If any file fails to process, a warning is printed, nothing is
        written and ({}, {}) is returned.
        Files with identical content share one hash, so a later file replaces
        the entry of an earlier one.
    '''
    # Placeholder for image files paths
    img_list_json = {}
    # Placeholder for image clip vectors
    clip_vectors_json = {}

    print ('[INFO] Running on Data Source...')

    # Walking thru files
    for root, _, files in os.walk(input_dir):

        for file in tqdm(files):
            # Get file path
            file_path = f'{root}/{file}'
            # Only png and jpg files are processed
            if os.path.splitext(file_path)[-1] not in ('.png', '.jpg'):
                continue

            try:
                # Hash on a copy of the shared hasher so the id depends on this
                # file only, not on every file hashed before it
                hash_id = get_hash(file_path, hasher.copy())
                # Compute CLIP vector; the context manager closes the image file
                with Image.open(file_path) as img:
                    clip_vector = compute_clip(img, clip_model, preprocess, device)

                # Image list dictionary creation
                img_list_json[hash_id] = {'file_path': file_path, 'file_name': file}
                # CLIP vectors dictionary creation
                clip_vectors_json[hash_id] = {'clip_vector': clip_vector.tolist()}

            except Exception as e:
                # The first failure aborts the whole run with empty results
                print (f'[WARNING] Error when processing file: {e}')
                return {}, {}

    # Writing to file
    with open(img_list_json_path, 'w') as img_list_file:
        json.dump(img_list_json, img_list_file, indent=4)

    with open(clip_vectors_json_path, 'w') as clip_vectors_file:
        json.dump(clip_vectors_json, clip_vectors_file, indent=4)

    # Number of images
    n_images = len(img_list_json)
    print (f'[INFO] Completed. Number of images: {n_images}')

    return img_list_json, clip_vectors_json

# Run the function: writes both JSON files and keeps the two resulting dictionaries in memory
img_list_json, clip_vectors_json = create_data_source_json(input_dir, img_list_json_path, clip_vectors_json_path)

### Loading The Data Back From File List JSON and Image CLIP Vector JSON Files

In [None]:
# Locations of the two JSON files, specified again for this cell
img_list_json_path = './image_list.json'
clip_vectors_json_path = './image_clip_vectors.json'

# Dictionary of image file paths, parsed from the image list JSON
with open(img_list_json_path, 'r') as file:
    img_list_json = json.loads(file.read())

# Dictionary of image CLIP vectors, parsed from the CLIP vectors JSON
with open(clip_vectors_json_path, 'r') as file:
    clip_vectors_json = json.loads(file.read())

### Get 1 Random 'Reference' Image

In [None]:
# Every image hash, i.e. the keys of img_list_json
hash_list = list(img_list_json)
# Draw one hash at random to act as the reference image
ref_file_hash = random.choice(hash_list)

# File path of the reference image
ref_entry = img_list_json[ref_file_hash]
ref_file_path = ref_entry['file_path']
# CLIP vector of the reference image: element 0 of the stored nested list
ref_clip_rows = clip_vectors_json[ref_file_hash]['clip_vector']
ref_file_clip_vector = np.array(ref_clip_rows[0])

### Calculate Dot Product Between Reference Image CLIP Vector and All Other Images CLIP Vectors

In [None]:
# Similarity function definition
def get_hashes_with_similar_clip(ref_file_hash, ref_file_clip_vector, clip_vectors_json, n_top_similar):
    '''
    Rank every image except the reference by the dot product of the
    L2-normalized CLIP vectors (cosine similarity) and return the best ones.

    Parameters:
        ref_file_hash: hash of the reference image; its own entry is skipped.
        ref_file_clip_vector: 1-D array with the reference CLIP vector.
        clip_vectors_json: dict of the form
            {<hash>: {'clip_vector': [<vector as list>]}, ...}
        n_top_similar: maximum number of results to return.

    Returns:
        List of at most `n_top_similar` tuples, highest dot product first:
        [(<dot_product>, <sample_image_hash>), ...]
    '''
    # The reference vector is the same for every sample: normalize it once
    norm_ref_file_clip_vector = ref_file_clip_vector / linalg.norm(ref_file_clip_vector)

    dot_products = []
    for sample_hash, entry in clip_vectors_json.items():
        if sample_hash == ref_file_hash:
            # Do not compare the reference image with itself
            continue

        # Normalize the sample vector (element 0 of the stored nested list)
        sample_clip_vector = np.array(entry['clip_vector'][0])
        norm_sample_clip_vector = sample_clip_vector / linalg.norm(sample_clip_vector)

        # Dot product of two unit vectors
        dot_product = np.dot(norm_ref_file_clip_vector, norm_sample_clip_vector)
        dot_products.append((dot_product, sample_hash))

    # Highest similarity first; ties are ordered by hash (tuple comparison)
    dot_products.sort(reverse=True)

    # Honour the caller's n_top_similar (it used to be overwritten with 8)
    return dot_products[:n_top_similar]

# Number of most similar images requested; also used as the number of subplot columns when displaying them
n_top_similar = 8

# Run the function: returns a list of (<dot_product>, <sample_image_hash>) tuples, most similar first
top_similar_images = get_hashes_with_similar_clip(ref_file_hash, ref_file_clip_vector, clip_vectors_json, n_top_similar)
print (top_similar_images)

### Show Reference Image

In [None]:
# Show the reference image, opened from the path stored in ref_file_path
plt.imshow(Image.open(ref_file_path))

### Show Top Similar Images

In [None]:
# Show 'n_top_similar' most similar images. Similarity ranking: from left to right.

fig, ax = plt.subplots(1, n_top_similar, figsize = (20,20))
# plt.subplots returns a single Axes object (not an array) when only one
# subplot is requested; wrap it so it can be iterated for any n_top_similar >= 1.
axes = np.atleast_1d(ax)

print ('[INFO] Showing Similar Images. Similarity ranking: from left to right.')
# Each item is a (<dot_product>, <sample_image_hash>) pair; pair it with its subplot
for item, axis in zip(tqdm(top_similar_images), axes):
    sample_image_hash = item[1]
    file_path = img_list_json[sample_image_hash]['file_path']
    axis.imshow(Image.open(file_path))