In [None]:
import os
import cv2
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification, AutoModel
import numpy as np
import matplotlib.pyplot as plt
from loadimg import load_img
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

from src.tools import art_cropper
from src.transformations import image_transforms_no_tensor
from torch.nn.functional import cosine_similarity
import matplotlib.pyplot as plt

In [None]:
# Build an image dataset from the local ./data folder (one sub-folder per class).
dataset = load_dataset(path="imagefolder", data_dir="./data")

In [None]:
# Upload the raw image dataset to the Hub (plain string: the f-prefix had no placeholders).
dataset.push_to_hub("HichTala/yugioh")

In [None]:
# Reload the raw image dataset from the Hub copy pushed above.
dataset = load_dataset(path="HichTala/yugioh")

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Single checkpoint id shared by the processor and the model, so both always match.
MODEL_ID = "google/vit-huge-patch14-224-in21k"

processor = AutoProcessor.from_pretrained(MODEL_ID)
# `.to(device)` instead of `device_map=device`: device_map requires the optional
# `accelerate` package, while a plain move is enough for a single device.
# `.eval()` makes inference mode explicit (disables dropout).
model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()


In [None]:
def embed(batch):
    """Add an ``"embeddings"`` column with one pooled ViT vector per image.

    Meant for ``Dataset.map(batched=True)``: ``batch["image"]`` is a list of
    images, and the returned batch gains a parallel array of float vectors.

    ``AutoModel`` resolves this checkpoint to a plain ViT encoder, which has no
    ``get_image_features`` method (that is a CLIP API). The pooled output is used
    instead, which is the same representation ``search`` queries with.
    """
    pixel_values = processor(images=batch["image"], return_tensors="pt")["pixel_values"]
    with torch.no_grad():  # inference only: no autograd graph, less memory
        img_emb = model(pixel_values=pixel_values.to(device)).pooler_output
    # datasets cannot store CUDA / grad-tracking tensors, so move them to host numpy.
    batch["embeddings"] = img_emb.cpu().numpy()
    return batch


embedded_dataset = dataset.map(embed, batched=True, batch_size=16)

In [None]:
# Publish the dataset together with its precomputed embedding column.
embedded_dataset.push_to_hub(repo_id="HichTala/yugioh-embeddings")


In [None]:
# Only the "train" split is needed for retrieval.
dataset = load_dataset(path="HichTala/yugioh-embeddings", split="train")

In [None]:
# Build a FAISS index over the "embeddings" column for nearest-neighbour search.
# NOTE(review): no metric is given, so the library default applies — confirm it
# matches how the embeddings should be compared (they are not normalised here).
dataset = dataset.add_faiss_index(column="embeddings")

In [None]:
def search(query, k: int = 4, return_scores: bool = False):
    """Embed a query image and return the ``k`` nearest rows of ``dataset``.

    Args:
        query: the query image (e.g. a PIL image, as the notebook passes), or
            anything else ``processor`` accepts as ``images``. It is not a file path.
        k: number of neighbours to retrieve.
        return_scores: when True, also return the FAISS scores.

    Returns:
        ``retrieved_examples`` (a dict of column -> list of values) from
        ``dataset.get_nearest_examples``. If ``return_scores`` is True, returns
        ``(scores, retrieved_examples)`` instead.
    """
    pixel_values = processor(images=query, return_tensors="pt")["pixel_values"].to(device)
    with torch.no_grad():  # pure inference, no autograd graph
        # A single image gives a batch of one, so take row 0 of the pooled output.
        img_emb = model(pixel_values=pixel_values).pooler_output[0]
    # The datasets FAISS API expects a numpy vector, not a torch tensor.
    img_emb = img_emb.cpu().numpy()

    scores, retrieved_examples = dataset.get_nearest_examples("embeddings", img_emb, k=k)
    if return_scores:
        return scores, retrieved_examples
    return retrieved_examples

In [None]:
def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    """Resize ``image`` proportionally to a target width or height.

    If ``width`` is given it takes precedence and ``height`` is ignored.
    Otherwise the image is scaled to ``height``. With neither set, the input
    object is returned unchanged. Scaled sizes are truncated with ``int()``.
    ``inter`` is forwarded to ``cv2.resize`` as the interpolation flag.
    """
    # Read the shape first: an object without .shape fails here even when no resize is requested.
    src_h, src_w = image.shape[:2]
    if width is None and height is None:
        return image

    if width is not None:
        scale = width / float(src_w)
        target = (width, int(src_h * scale))
    else:
        scale = height / float(src_h)
        target = (int(src_w * scale), height)

    # cv2 expects the size as (width, height).
    return cv2.resize(image, target, interpolation=inter)

In [None]:
# Path of the query photo used by the retrieval cells below.
query = 'queries/rh_0.jpg'

In [None]:
# Load the query with OpenCV, scale it to a fixed width and convert it to a PIL image.
QUERY_WIDTH = 536  # target width in pixels; height follows the aspect ratio

image_bgr = cv2.imread(query)
if image_bgr is None:
    # cv2.imread returns None for missing/unreadable files instead of raising.
    raise FileNotFoundError(f"cv2 could not read {query!r}")
image_bgr = image_resize(image_bgr, width=QUERY_WIDTH)

# OpenCV decodes to BGR channel order; PIL expects RGB.
image = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
image

In [None]:
# Reload the query image with loadimg, replacing the OpenCV-resized `image`
# from the previous cell, and display it.
image = load_img(query)
image

In [None]:
# Reload the query and pass it through the project's `art_cropper`.
# NOTE(review): presumably crops the card down to its artwork — confirm in src/tools.
# The result is the `image` that the retrieval/plot cell queries with.
image = load_img(query)
image = art_cropper(image)
image

In [None]:
# Results are shown on a square nb_rows x nb_rows grid, one hit per panel.
nb_rows = 2
nb_image = nb_rows * nb_rows


In [None]:
# Retrieve the nb_image nearest cards to the query and show them on a grid.
retrieved_examples = search(image, nb_image)
# squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
fig, axes = plt.subplots(nb_rows, nb_rows, squeeze=False)
# zip stops at the shorter sequence: if fewer than nb_image hits come back, the
# remaining panels stay blank instead of raising IndexError.
for ax, hit_image in zip(axes.flat, retrieved_examples["image"]):
    ax.imshow(hit_image)
for ax in axes.flat:
    ax.axis('off')
plt.show()

In [None]:
# Device used by `infer`; picks CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
def infer(image):
    """Return the pooled ViT embedding of ``image``, keeping the batch dimension.

    Runs under ``torch.no_grad()`` so repeated calls (the dataset-wide
    similarity loop calls this once per file) build no autograd graphs.
    """
    inputs = processor(image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.pooler_output

In [None]:
# Peek at the first row of the indexed "train" split loaded from
# HichTala/yugioh-embeddings, including its "embeddings" column.
dataset[0]

In [None]:
# The three query images compared against every dataset image below.
query1, query2, query3 = (
    "queries/rh_0.jpg",
    "queries/rh_0_enhanced.png",
    "queries/image.png",
)

In [None]:
# PIL crop boxes are (left, upper, right, lower) pixel coordinates, one per query file.
box_real = (13, 30, 81, 97)
box_2 = (25, 58, 161, 194)
box_3 = (53, 118, 322, 385)

image_real = Image.open(query1).crop(box_real)
image_2 = Image.open(query2).crop(box_2)
image_3 = Image.open(query3).crop(box_3)

In [None]:
# Preview the third query in an OpenCV window; blocks until a key is pressed.
# NOTE(review): needs a local display, so this will not work on a headless kernel.
preview = cv2.imread(query3)
if preview is None:
    # cv2.imread returns None for missing/unreadable files instead of raising.
    raise FileNotFoundError(f"cv2 could not read {query3!r}")
cv2.imshow("", preview)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [None]:
# Apply the project's non-tensor transform to each cropped query. The same names
# are rebound, so re-running this cell transforms already-transformed images again.
image_real, image_2, image_3 = map(image_transforms_no_tensor, (image_real, image_2, image_3))

In [None]:
# Pooled ViT embeddings of the three transformed query images.
embed_real, embed_2, embed_3 = map(infer, (image_real, image_2, image_3))

In [None]:
# Preview the first query in an OpenCV window; blocks until a key is pressed.
# NOTE(review): needs a local display, so this will not work on a headless kernel.
preview = cv2.imread(query1)
if preview is None:
    # cv2.imread returns None for missing/unreadable files instead of raising.
    raise FileNotFoundError(f"cv2 could not read {query1!r}")
cv2.imshow("", preview)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [None]:
# Cosine similarity between each query embedding and every image under ./data,
# keyed by the image's sub-directory. When a sub-directory holds several images,
# the best (max) score is kept, so results do not depend on os.walk file order.
similarities = {}   # scores against embed_real
similarities2 = {}  # scores against embed_2
similarities3 = {}  # scores against embed_3

# Collect the files first so the progress bar total is exact, not hard-coded.
image_files = [
    (subdir, os.path.join(subdir, file))
    for subdir, _dirs, files in os.walk("./data")
    for file in files
]
# Each query embedding writes to its own dict (previously the third query
# overwrote `similarities` and `similarities3` stayed empty).
query_targets = (
    (embed_real, similarities),
    (embed_2, similarities2),
    (embed_3, similarities3),
)

for subdir, path in tqdm(image_files, desc="Scoring dataset images", colour='cyan'):
    # Close the file handle promptly; force RGB so RGBA/grayscale files match
    # the 3-channel input the processor normalises.
    with Image.open(path) as image_gen:
        embed_gen = infer(image_gen.convert("RGB"))
    for query_emb, scores in query_targets:
        score = cosine_similarity(query_emb, embed_gen, dim=1).item()
        scores[subdir] = max(score, scores.get(subdir, float("-inf")))

In [None]:
# NOTE(review): `embed_gen` is whichever image the similarity loop scored last,
# so this value depends on os.walk order (hidden state from the loop cell).
similarity_score = cosine_similarity(x1=embed_real, x2=embed_gen, dim=1)
print(similarity_score)

In [None]:
# Same computation as the previous cell (identical inputs, identical result).
# NOTE(review): relies on `embed_gen` left over from the similarity loop.
similarity_score = cosine_similarity(x1=embed_real, x2=embed_gen, dim=1)
print(similarity_score)

In [None]:
# (subdir, score) pairs, highest cosine similarity first.
similarities_sorted = sorted(
    similarities.items(),
    key=lambda pair: pair[1],
    reverse=True,
)

In [None]:
# Rebuild `similarities` in descending-score order (dicts keep insertion order).
# Iterating a dict directly yields only its keys (path strings), which cannot be
# unpacked into (k, v), so the sorted (key, score) pairs are used instead.
similarities = dict(similarities_sorted)

In [None]:
# Same descending ordering for the scores of the second and third queries.
similarities2_sorted, similarities3_sorted = (
    sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
    for scores in (similarities2, similarities3)
)


In [None]:
# Show only the first entries (in dict order). The full dict has one entry per
# sub-directory of ./data and would flood the notebook output.
TOP_N = 10
dict(list(similarities.items())[:TOP_N])

In [None]:
# Look up one card's score by its sub-directory path. Keys come from
# os.walk("./data"), so the separator is platform-dependent ('/' on POSIX);
# raises KeyError if that directory was not scanned.
similarities['./data/Harpie-Queen-0-75064463']