In [None]:
# Install the third-party dependencies. `%pip` (rather than `!pip`) installs
# into the environment of the running kernel instead of whatever `pip` the
# subshell happens to resolve.
%pip install tensorflow scikit-learn matplotlib


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
from sklearn.metrics.pairwise import cosine_similarity
from google.colab import files
from zipfile import ZipFile
from PIL import Image


In [None]:
# Prompt for the image collection that will be indexed.
uploaded = files.upload() # Must be a .zip archive containing the images

In [None]:
# Use the first uploaded entry as the archive path. A cancelled/empty upload
# would otherwise surface as a bare StopIteration from next(), so fail with an
# explicit message instead.
zip_path = next(iter(uploaded), None)
if zip_path is None:
    raise ValueError("Nenhum arquivo foi enviado; envie um arquivo .zip com as imagens.")

# Unpack every member of the archive into the local "images" directory.
with ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall("images")

In [None]:
# Feature extractor: ResNet50 with ImageNet weights, built without its top
# (classification) layers and with average pooling applied to the output.
model = ResNet50(weights='imagenet', include_top=False, pooling='avg')

def extract_features(img_path):
    """Compute the feature vector of a single image file.

    The image is loaded at 224x224, converted to an array, given a leading
    batch axis, run through ``preprocess_input`` and then through the
    module-level ``model``.

    Args:
        img_path: Path of the image file to load.

    Returns:
        The model's prediction for the image, flattened to one dimension.
    """
    loaded = image.load_img(img_path, target_size=(224, 224))
    # The model predicts on batches, so wrap the single image in a batch of one.
    batch = np.expand_dims(image.img_to_array(loaded), axis=0)
    prediction = model.predict(preprocess_input(batch), verbose=0)
    return prediction.flatten()

In [None]:
# Build the search index: one feature vector per readable image found under
# the "images" directory. `image_paths[i]` is the file behind `features_list[i]`.
image_paths = []
features_list = []

# The loop variable is named `filenames`, not `files`, so it does not rebind
# the `files` module imported from google.colab that other cells call.
for root, _, filenames in os.walk("images"):
    for filename in filenames:
        # Compare in lower case so upper-case extensions (e.g. ".JPG") are indexed too.
        if not filename.lower().endswith((".jpg", ".png", ".jpeg")):
            continue
        path = os.path.join(root, filename)
        try:
            features = extract_features(path)
        except Exception as exc:
            # Best effort: skip images that cannot be processed, but report the
            # cause. Catching Exception (not a bare except) keeps
            # KeyboardInterrupt able to stop a long indexing run.
            print("Erro com imagem:", path, "-", exc)
            continue
        image_paths.append(path)
        features_list.append(features)

features_array = np.array(features_list)

# An empty index cannot be searched; stop here with a clear message rather
# than failing later inside the similarity computation.
if not features_list:
    raise ValueError("Nenhuma imagem válida foi encontrada na pasta 'images'.")

In [None]:
from google.colab import files

consulta = files.upload()  # a single regular image file (.png, .jpg, etc.)

# Use the first uploaded entry as the query image. A cancelled/empty upload
# would otherwise surface as a bare StopIteration from next().
img_consulta_path = next(iter(consulta), None)
if img_consulta_path is None:
    raise ValueError("Nenhuma imagem de consulta foi enviada.")

# Feature vector of the query image, computed the same way as the indexed images.
vetor_consulta = extract_features(img_consulta_path)

In [None]:
# Cosine similarity of the query vector against every indexed vector. The query
# is wrapped in a list to form a single-row input, and row 0 of the result holds
# one score per indexed image.
sim_scores = cosine_similarity(
    [vetor_consulta],
    features_array,
)[0]

# argsort orders ascending; reverse it and keep the five highest-scoring indices.
top_indices = np.flip(np.argsort(sim_scores))[:5]

In [None]:
def show_images(indices):
    """Plot the indexed images at ``indices`` side by side with their scores.

    Reads the module-level ``image_paths`` (to locate each file) and
    ``sim_scores`` (for the title of each panel).

    Args:
        indices: Sequence of positions into ``image_paths`` / ``sim_scores``,
            drawn left to right in the order given.
    """
    plt.figure(figsize=(15, 5))
    total = len(indices)
    for position, idx in enumerate(indices, start=1):
        # Open via a context manager so the file handle is released as soon as
        # the image has been drawn, instead of lingering until garbage collection.
        with Image.open(image_paths[idx]) as img:
            plt.subplot(1, total, position)
            plt.imshow(img)
        plt.axis("off")
        plt.title(f"Score: {sim_scores[idx]:.2f}")
    plt.show()

# Render the top-ranked matches held in `top_indices`.
show_images(top_indices)