In [None]:
import pandas as pd 
import numpy as np 
import glob 
from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image
import cv2 


In [None]:
import torch

# Load the preprocessing pipeline that matches the ViT checkpoint used below.
# It is called later as feature_extractor(images=..., return_tensors="pt").
feature_extractor = ViTFeatureExtractor.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)

In [None]:
# Smoke test: load one local image, load the ViT backbone, and preprocess the image.
# NOTE(review): the image path is an absolute, machine-specific location; the notebook
# will not run elsewhere until it points at a file that exists on that machine.
test = cv2.imread("/Users/menghang/Desktop/ml-dev/ml-replicate/vector-db/Alejandro_Toledo_0037.jpg")
if test is None:
    # cv2.imread does not raise on a missing/undecodable file; it returns None,
    # which would otherwise only fail later inside the feature extractor.
    raise OSError("cv2.imread could not load the smoke-test image; check the path above")

model_name = 'google/vit-base-patch16-224-in21k'
model_transformer = ViTModel.from_pretrained(model_name)

# NOTE(review): OpenCV decodes images in BGR channel order, while the ViT
# preprocessing presumably expects RGB. Every cell that builds an embedding uses the
# same BGR input, so any channel-order fix must be applied to all of them together.
inputs = feature_extractor(images=test, return_tensors="pt")

In [None]:
# Selects the embedding branch used by the indexing loops further down:
# "transformer" takes the ViT path; any other value takes the ResNet path.
arch = "transformer"

In [None]:
# Forward pass on the preprocessed smoke-test image; gradients are not needed
# for feature extraction, so autograd is disabled for the call.
with torch.no_grad():
    outputs = model_transformer(**inputs)

# Keep the model's last hidden state as the extracted features and report its shape.
features = outputs.last_hidden_state
print(features.shape)

In [None]:
# Keep only position 0 along the second axis of the last hidden state
# (for ViT this is presumably the [CLS] token — confirm against the model docs).
# The slice is taken from `outputs` rather than from `features` itself so the cell
# is idempotent: re-slicing an already 2-D `features` would raise IndexError.
features = outputs.last_hidden_state[:, 0, :]

In [None]:
# Per-category lists of image paths; filled by the glob cells that follow.
data = {label: [] for label in ("brain", "butterfly")}



In [None]:
# Collect every path in the "brain" category folder.
# Assigning the glob result (instead of appending in a loop) keeps the cell
# idempotent: re-running it no longer duplicates paths in data["brain"].
# NOTE(review): glob does not guarantee a sorted order, so which files end up in the
# later [:10] subset may differ between machines — sort here if that matters.
data["brain"] = glob.glob("caltech-101/101_ObjectCategories/brain/*")


In [None]:
# Collect every path in the "butterfly" category folder.
# Assigning the glob result (instead of appending in a loop) keeps the cell
# idempotent: re-running it no longer duplicates paths in data["butterfly"].
# NOTE(review): glob does not guarantee a sorted order, so which files end up in the
# later [:10] subset may differ between machines — sort here if that matters.
data["butterfly"] = glob.glob("caltech-101/101_ObjectCategories/butterfly/*")

In [None]:
# Keep only the first 10 paths of each category so the demo stays small.
data.update(
    brain=data["brain"][:10],
    butterfly=data["butterfly"][:10],
)


In [None]:
# Accumulator for the embedding vectors produced by the indexing loops below.
embedding_data = []
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn

# ImageNet-pretrained ResNet-50 used as the non-transformer feature extractor.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Drop the last child module so the network outputs features rather than class scores.
model = nn.Sequential(*list(model.children())[:-1])

# torchvision constructs models in training mode. Switch to eval mode so BatchNorm
# uses its stored running statistics (and stops updating them) during feature
# extraction; otherwise embeddings depend on per-batch statistics.
model.eval()


In [None]:
# Sanity check: show the last butterfly path kept after truncation.
print(data["butterfly"][-1])

In [None]:
import cv2

# Embed every brain image and append one vector per image to `embedding_data`.
# NOTE(review): re-running this cell appends the same images again; reset
# `embedding_data` first if the cell has to be executed more than once.
for brain_path in data["brain"]:
    img = cv2.imread(brain_path)
    if img is None:
        # cv2.imread returns None (no exception) when a file is missing or undecodable.
        raise OSError(f"cv2.imread could not load image: {brain_path}")
    img = cv2.resize(img, (224, 224))

    # Inference only: disable autograd, as the single-image smoke-test cell does.
    with torch.no_grad():
        if arch == "transformer":
            # NOTE(review): `img` is still in OpenCV's BGR order here; the query cells
            # use the same order, so change the channel order everywhere or nowhere.
            img = feature_extractor(images=img, return_tensors="pt")
            features = model_transformer(**img).last_hidden_state
            # Position 0 of the second axis, first (only) batch item, as a plain list.
            features = features[:, 0, :][0].tolist()
            embedding_data.append(features)
        else:
            # HWC -> CHW, then a float32 batch of one. torch.from_numpy avoids the very
            # slow element-wise conversion done by torch.Tensor([ndarray]).
            transposed_img = img.transpose(2, 0, 1)
            batch = torch.from_numpy(np.ascontiguousarray(transposed_img)).float().unsqueeze(0)
            features = model(batch)
            flattened_features = features.view(features.size(0), -1)[0].tolist()
            embedding_data.append(flattened_features)

In [None]:
# Embed every butterfly image and append one vector per image to `embedding_data`.
# NOTE(review): re-running this cell appends the same images again; reset
# `embedding_data` first if the cell has to be executed more than once.
for bf_path in data["butterfly"]:
    img = cv2.imread(bf_path)
    if img is None:
        # cv2.imread returns None (no exception) when a file is missing or undecodable.
        raise OSError(f"cv2.imread could not load image: {bf_path}")
    img = cv2.resize(img, (224, 224))

    # Inference only: disable autograd, as the single-image smoke-test cell does.
    with torch.no_grad():
        if arch == "transformer":
            # NOTE(review): `img` is still in OpenCV's BGR order here; the query cells
            # use the same order, so change the channel order everywhere or nowhere.
            img = feature_extractor(images=img, return_tensors="pt")
            features = model_transformer(**img).last_hidden_state
            # Position 0 of the second axis, first (only) batch item, as a plain list.
            features = features[:, 0, :][0].tolist()
            embedding_data.append(features)
        else:
            # HWC -> CHW, then a float32 batch of one. torch.from_numpy avoids the very
            # slow element-wise conversion done by torch.Tensor([ndarray]).
            transposed_img = img.transpose(2, 0, 1)
            batch = torch.from_numpy(np.ascontiguousarray(transposed_img)).float().unsqueeze(0)
            features = model(batch)
            flattened_features = features.view(features.size(0), -1)[0].tolist()
            embedding_data.append(flattened_features)

In [None]:
import chromadb

# Create a Chroma client and open (or create) the collection that will hold
# the image embeddings.
client = chromadb.Client()
collection = client.get_or_create_collection(name="image-searches")

In [None]:
# Store every embedding under a positional id ("id-0", "id-1", ...); the id is the
# only link back to the image, via the order in which the loops appended vectors.
embedding_ids = [f"id-{position}" for position in range(len(embedding_data))]
collection.add(embeddings=embedding_data, ids=embedding_ids)


['caltech-101/101_ObjectCategories/brain/image_0032.jpg',
 'caltech-101/101_ObjectCategories/brain/image_0026.jpg',
 'caltech-101/101_ObjectCategories/brain/image_0027.jpg',
 'caltech-101/101_ObjectCategories/brain/image_0033.jpg',
 'caltech-101/101_ObjectCategories/brain/image_0019.jpg',
 'caltech-101/101_ObjectCategories/brain/image_0025.jpg',
 'caltech-101/101_ObjectCategories/brain/image_0031.jpg',
 'caltech-101/101_ObjectCategories/brain/image_0030.jpg',
 'caltech-101/101_ObjectCategories/brain/image_0024.jpg',
 'caltech-101/101_ObjectCategories/brain/image_0018.jpg']

In [79]:
# Build the ViT query embedding for one brain image, using the same preprocessing
# as the transformer branch of the indexing loops.
img = cv2.imread("caltech-101/101_ObjectCategories/brain/image_0032.jpg")
if img is None:
    # cv2.imread returns None (no exception) when a file is missing or undecodable.
    raise OSError("cv2.imread could not load caltech-101/101_ObjectCategories/brain/image_0032.jpg")
img = cv2.resize(img, (224, 224))
inputs = feature_extractor(images=img, return_tensors="pt")

# Inference only: disable autograd, as the single-image smoke-test cell does.
with torch.no_grad():
    features = model_transformer(**inputs).last_hidden_state

# Position 0 of the second axis, first (only) batch item, as a plain list of floats.
features = features[:, 0, :][0].tolist()




In [None]:
# Display the butterfly paths currently held in `data`.
data["butterfly"]

In [None]:
# Preview one butterfly image at the 224x224 size used for embedding.
img = cv2.imread("caltech-101/101_ObjectCategories/butterfly/image_0032.jpg")
if img is None:
    # cv2.imread returns None (no exception) when a file is missing or undecodable.
    raise OSError("cv2.imread could not load caltech-101/101_ObjectCategories/butterfly/image_0032.jpg")
img = cv2.resize(img, (224, 224))

# Render inline instead of cv2.imshow/waitKey(0): the native window blocks the kernel
# until a key is pressed (stalling Run All) and is unavailable on headless kernels.
# OpenCV arrays are BGR, so convert to RGB for display only; `img` itself is unchanged.
Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))


In [None]:
# ResNet embedding for a single butterfly image.
# NOTE(review): this overwrites the notebook-level `features` that the later
# collection.query cell reads — confirm which embedding that query should use.
img = cv2.imread("caltech-101/101_ObjectCategories/butterfly/image_0041.jpg")
if img is None:
    # cv2.imread returns None (no exception) when a file is missing or undecodable.
    raise OSError("cv2.imread could not load caltech-101/101_ObjectCategories/butterfly/image_0041.jpg")
# Resize, then reorder HWC -> CHW.
img = cv2.resize(img, (224, 224)).transpose(2, 0, 1)

# Inference only. torch.from_numpy avoids the very slow element-wise conversion
# done by torch.Tensor([ndarray]); the result is the same float32 batch of one.
with torch.no_grad():
    features = model(torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0))
flattened_features = features.view(features.size(0), -1)[0].tolist()

In [None]:
# ResNet embedding for a single brain image.
# NOTE(review): this overwrites the notebook-level `features` that the later
# collection.query cell reads — confirm which embedding that query should use.
img = cv2.imread("caltech-101/101_ObjectCategories/brain/image_0033.jpg")
if img is None:
    # cv2.imread returns None (no exception) when a file is missing or undecodable.
    raise OSError("cv2.imread could not load caltech-101/101_ObjectCategories/brain/image_0033.jpg")
# Resize, then reorder HWC -> CHW.
img = cv2.resize(img, (224, 224)).transpose(2, 0, 1)

# Inference only. torch.from_numpy avoids the very slow element-wise conversion
# done by torch.Tensor([ndarray]); the result is the same float32 batch of one.
with torch.no_grad():
    features = model(torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0))
flattened_features = features.view(features.size(0), -1)[0].tolist()

In [80]:
# Nearest-neighbour search: return the 10 closest stored embeddings and their distances.
# NOTE(review): `features` is whatever the most recently executed cell assigned. The
# execution counts (79 -> 80) suggest it was the ViT query vector; in a top-to-bottom
# run the preceding ResNet cells overwrite `features` with a raw model output instead,
# which presumably does not match the stored vectors — confirm and use a dedicated name.
results = collection.query(query_embeddings=features, n_results=10, include=["distances"])

In [81]:
# Display the raw query response (matched ids and their distances).
results

{'ids': [['id-0',
   'id-8',
   'id-3',
   'id-6',
   'id-4',
   'id-9',
   'id-1',
   'id-7',
   'id-2',
   'id-5']],
 'distances': [[0.0,
   13.09632682800293,
   13.95283317565918,
   14.6825590133667,
   14.836503982543945,
   16.509275436401367,
   16.68521499633789,
   17.651121139526367,
   19.288576126098633,
   24.872194290161133]],
 'metadatas': None,
 'embeddings': None,
 'documents': None,
 'uris': None,
 'data': None,
 'included': ['distances']}