## Importando variáveis de ambiente
Esse notebook prevê a existência de 2 variáveis de ambiente no arquivo .env desse projeto:
- DATA_FOLDER
- TEST_DATA_FOLDER
- TRAINED_MODELS_FOLDER

In [1]:
from dotenv import load_dotenv
import os

load_dotenv(dotenv_path=".env", override=True)

DATA_FOLDER = os.getenv("DATA_FOLDER")
TEST_DATA_FOLDER = os.getenv("TEST_DATA_FOLDER")
TRAINED_MODELS_FOLDER = os.getenv("TRAINED_MODELS_FOLDER")

## Bibliotecas

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import models, transforms
from PIL import Image
import pandas as pd
import numpy as np

import sqlite3
import faiss

import matplotlib.pyplot as plt
import math

## Configurações

In [3]:
model_path = os.path.join(TRAINED_MODELS_FOLDER, "best_resnet50.pth")  # caminho do modelo treinado

DB_PATH = os.path.join(DATA_FOLDER, "metadata.db")

device = "cuda" if torch.cuda.is_available() else "cpu"

TOP_K_CLASSES = 3
TOP_K_RESULTS = 3   # por classe

num_classes = 30

## Transform

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

## Carrega o modelo ResNet50 para classificação e extração de características

In [5]:
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

model.load_state_dict(torch.load(model_path, map_location=device))

model.to(device)
model.eval()

feature_extractor = nn.Sequential(*list(model.children())[:-1])
feature_extractor.eval()

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


## Função de classificação e busca por similariadade

In [14]:
import time

def classify_and_find_similar(img_path):
    timings = {
        "classificacao_ms": 0.0,
        "embedding_ms": 0.0,
        "faiss_por_classe_ms": {},  # {class_id: tempo_ms}
        "total_ms": 0.0
    }

    # =============================
    # Início da medição total
    # =============================
    t_total_start = time.perf_counter()

    # 1. Conectar ao banco
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()

    # 2. Carregar imagem
    img = Image.open(img_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)

    # =============================
    # CLASSIFICAÇÃO
    # =============================
    t_class_start = time.perf_counter()

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1).cpu().numpy().flatten()

    t_class_end = time.perf_counter()

    # 3. Obter top-K classes
    top_classes_idx = probs.argsort()[-TOP_K_CLASSES:][::-1]
    top_classes = []

    for class_id in top_classes_idx:
        cursor.execute(f"SELECT name, index_path FROM class WHERE id = {class_id}")
        row = cursor.fetchone()

        if row is None:
            print(f"[DEBUG][WARN] Classe {class_id} não encontrada no banco!")
            continue

        class_name, index_path = row
        prob = float(probs[class_id])

        top_classes.append({
            "class_id": class_id,
            "class_name": class_name,
            "index_path": index_path,
            "probability": prob
        })

    # =============================
    # EXTRAÇÃO DE EMBEDDING
    # =============================
    t_emb_start = time.perf_counter()

    with torch.no_grad():
        feats = feature_extractor(img_tensor)
        feats = feats.view(feats.size(0), -1)
        feats = torch.nn.functional.normalize(feats, p=2, dim=1)
        feats = feats.cpu().numpy().astype("float32")

    t_emb_end = time.perf_counter()

    # =============================
    # CONSULTAS FAISS
    # =============================

    similar_images = []

    for cls in top_classes:
        class_id = cls["class_id"]

        t_faiss_start = time.perf_counter()

        faiss_index = faiss.read_index(cls["index_path"])
        distances, retrieved_ids = faiss_index.search(feats, TOP_K_RESULTS)

        t_faiss_end = time.perf_counter()
        timings["faiss_por_classe_ms"][class_id] = round((t_faiss_end - t_faiss_start) * 1000, 3)

        retrieved_ids = retrieved_ids[0]
        distances = distances[0]

        # Busca metadados no SQLite
        for img_id, dist in zip(retrieved_ids, distances):
            cursor.execute(f"SELECT image_path FROM image WHERE id = {int(img_id)}")
            result = cursor.fetchone()

            if result is None:
                print(f"[DEBUG][WARN] ID {img_id} não encontrado no banco.")
                continue
            
            image_path = result[0]

            similar_images.append({
                "query_to_class": cls["class_name"],
                "retrieved_image_id": int(img_id),
                "image_path": image_path,
                "distance": float(dist)
            })

    conn.close()

    # Ordena por distância
    similar_images = sorted(similar_images, key=lambda x: x["distance"])

    # =============================
    # Tempo total
    # =============================
    t_total_end = time.perf_counter()
    timings["classificacao_ms"] = round((t_class_end - t_class_start) * 1000, 3)
    timings["embedding_ms"] = round((t_emb_end - t_emb_start) * 1000, 3)
    timings["total_ms"] = round((t_total_end - t_total_start) * 1000, 3)

    # Retorna também as métricas
    return top_classes, similar_images, timings


## Função para exibir os resultados

In [12]:
def show_results(query_image_path, classes, results, save_png=False, png_path="comparacao.png"):
    print("\n=== Exibindo resultados avançados ===")

    # ------------------------------------------------------
    # 1. Mostrar imagem de entrada
    # ------------------------------------------------------
    query_img = Image.open(query_image_path).convert("RGB")

    plt.figure(figsize=(6,6))
    plt.imshow(query_img)
    plt.axis("off")
    plt.title("Imagem de entrada (query)", fontsize=16)
    plt.show()

    # ------------------------------------------------------
    # 2. Determinar ranking das classes
    # ------------------------------------------------------
    ranked_classes = sorted(classes, key=lambda x: x["probability"], reverse=True)

    # Criar mapa class_name → (rank, prob)
    rank_map = {}
    for idx, c in enumerate(ranked_classes):
        rank_map[c["class_name"]] = {
            "rank": idx + 1,
            "prob": c["probability"]
        }

    # Cores por ranking
    rank_colors = {
        1: "blue",
        2: "darkorange",
        3: "gray"
    }

    # ------------------------------------------------------
    # 3. Imprimir Top-K Classes como resumo texto
    # ------------------------------------------------------
    print("\n=== Top Classes ===")
    for c in ranked_classes:
        r = rank_map[c["class_name"]]["rank"]
        p = rank_map[c["class_name"]]["prob"] * 100
        print(f"[Rank {r}] {c['class_name']} — {p:.4f}%")

    # ------------------------------------------------------
    # 4. Agrupar resultados por classe
    # ------------------------------------------------------
    grouped = {}
    for r in results:
        cls = r["query_to_class"]
        if cls not in grouped:
            grouped[cls] = []
        grouped[cls].append(r)

    # ------------------------------------------------------
    # 5. Visualizar cada classe com ranking, probabilidade e grid
    # ------------------------------------------------------
    for cls_name, images in grouped.items():

        info = rank_map.get(cls_name, None)

        if info:
            rank = info["rank"]
            prob = info["prob"] * 100
            rank_suffix = f" (Rank {rank} — {prob:.3f}%)"
            color = rank_colors.get(rank, "black")
        else:
            rank_suffix = ""
            color = "black"

        print(f"\n=== Classe: {cls_name}{rank_suffix} ===")

        images_sorted = sorted(images, key=lambda x: x["distance"])

        cols = 3
        rows = math.ceil(len(images_sorted) / cols)

        # ATIVA um layout mais inteligente
        fig, axes = plt.subplots(
            rows,
            cols,
            figsize=(6*cols, 4*rows),
            constrained_layout=True
        )

        # Permite titles sem serem cobertos
        fig.subplots_adjust(top=0.90)

        # Para aceitar casos com apenas 1 linha
        if rows == 1:
            axes = [axes]
        if cols == 1:
            axes = [[ax] for ax in axes]

        fig.suptitle(f"{cls_name}{rank_suffix}", fontsize=18, color=color)

        for i, item in enumerate(images_sorted):
            r = i // cols
            c = i % cols
            ax = axes[r][c]

            try:
                simg = Image.open(item["image_path"]).convert("RGB")
            except:
                ax.axis("off")
                continue

            ax.imshow(simg)
            ax.axis("off")

            # Destacar match exato
            if item["distance"] == 0.0:
                ax.set_title(
                    f"[MATCH EXATO]\nID: {item['retrieved_image_id']} | Dist: 0.0",
                    fontsize=12,
                    color="green",
                    fontweight="bold"
                )
                for side in ax.spines:
                    ax.spines[side].set_color("green")
                    ax.spines[side].set_linewidth(3)
            else:
                ax.set_title(
                    f"ID: {item['retrieved_image_id']}\nDist: {item['distance']:.4f}",
                    fontsize=11
                )

        plt.show()


    # ------------------------------------------------------
    # 6. PNG único com painel geral (opcional)
    # ------------------------------------------------------
    if save_png:
        print(f"[DEBUG] Salvando comparacao em PNG: {png_path}")

        all_imgs = results[:]
        cols = 4
        rows = math.ceil(len(all_imgs) / cols)

        fig = plt.figure(figsize=(5*cols, 4*rows))
        fig.suptitle("Comparação completa", fontsize=18)

        for i, r in enumerate(all_imgs):
            ax = plt.subplot(rows, cols, i+1)
            simg = Image.open(r["image_path"]).convert("RGB")
            ax.imshow(simg)
            ax.axis("off")
            ax.set_title(f"{r['query_to_class']} | {r['distance']:.3f}", fontsize=10)

        plt.tight_layout()
        plt.subplots_adjust(top=0.92)
        fig.savefig(png_path, dpi=180)
        print(f"[OK] PNG salvo em: {png_path}")

## Testando o CBIR

In [None]:
import os
from ipywidgets import Button, HBox, VBox, Output
from IPython.display import display, clear_output

def run_batch_visualization_notebook(test_folder):
    valid_ext = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    
    image_files = sorted([
        f for f in os.listdir(test_folder)
        if f.lower().endswith(valid_ext)
    ])

    if not image_files:
        print("Nenhuma imagem encontrada.")
        return

    index = {"i": 0}  # índice compartilhado entre callbacks
    out = Output()

    # Botões
    btn_prev = Button(description="← Anterior", button_style="warning", layout={"width": "150px"})
    btn_next = Button(description="Próxima →", button_style="info", layout={"width": "150px"})

    # ----------------------
    # Renderizar imagem atual
    # ----------------------
    def show_current():
        with out:
            clear_output(wait=True)

            i = index["i"]
            filename = image_files[i]
            img_path = os.path.join(test_folder, filename)

            print(f"[{i+1}/{len(image_files)}] Classificando: {filename}")

            classes, results = classify_and_find_similar(img_path)

            show_results(
                query_image_path=img_path,
                classes=classes,
                results=results
            )

    # ----------------------
    # BOTÃO PRÓXIMA
    # ----------------------
    def on_next(b):
        if index["i"] < len(image_files) - 1:
            index["i"] += 1
            show_current()
        else:
            with out:
                clear_output(wait=True)
                print("Fim da lista de imagens.")

    # ----------------------
    # BOTÃO ANTERIOR
    # ----------------------
    def on_prev(b):
        if index["i"] > 0:
            index["i"] -= 1
            show_current()
        else:
            with out:
                clear_output(wait=True)
                print("Você já está na primeira imagem.")

    # Liga os eventos
    btn_prev.on_click(on_prev)
    btn_next.on_click(on_next)

    # Interface
    display(VBox([
        HBox([btn_prev, btn_next]),
        out
    ]))

    # Mostrar a primeira imagem
    show_current()


# Executar
import os
from ipywidgets import Button, HBox, VBox, Output
from IPython.display import display, clear_output

def run_batch_visualization_notebook(test_folder):
    valid_ext = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    
    image_files = sorted([
        f for f in os.listdir(test_folder)
        if f.lower().endswith(valid_ext)
    ])

    if not image_files:
        print("Nenhuma imagem encontrada.")
        return

    index = {"i": 0}  # índice compartilhado entre callbacks
    out = Output()

    # Botões
    btn_prev = Button(description="← Anterior", button_style="warning", layout={"width": "150px"})
    btn_next = Button(description="Próxima →", button_style="info", layout={"width": "150px"})

    # ----------------------
    # Renderizar imagem atual
    # ----------------------
    def show_current():
        with out:
            clear_output(wait=True)

            i = index["i"]
            filename = image_files[i]
            img_path = os.path.join(test_folder, filename)

            print(f"[{i+1}/{len(image_files)}] Classificando: {filename}")

            classes, results, metrics = classify_and_find_similar(img_path)

            # Exibe o tempo de query
            print("Tempo total:", metrics["total_ms"], "ms")
            print("Classificação:", metrics["classificacao_ms"], "ms")
            print("Embedding:", metrics["embedding_ms"], "ms")
            print("FAISS por classe:")
            for class_id, t_ms in metrics["faiss_por_classe_ms"].items():
                print(f"  Classe {class_id}: {t_ms} ms")

            show_results(
                query_image_path=img_path,
                classes=classes,
                results=results
            )

    # ----------------------
    # BOTÃO PRÓXIMA
    # ----------------------
    def on_next(b):
        if index["i"] < len(image_files) - 1:
            index["i"] += 1
            show_current()
        else:
            with out:
                clear_output(wait=True)
                print("Fim da lista de imagens.")

    # ----------------------
    # BOTÃO ANTERIOR
    # ----------------------
    def on_prev(b):
        if index["i"] > 0:
            index["i"] -= 1
            show_current()
        else:
            with out:
                clear_output(wait=True)
                print("Você já está na primeira imagem.")

    # Liga os eventos
    btn_prev.on_click(on_prev)
    btn_next.on_click(on_next)

    # Interface
    display(VBox([
        HBox([btn_prev, btn_next]),
        out
    ]))

    # Mostrar a primeira imagem
    show_current()


# Executar
run_batch_visualization_notebook(TEST_DATA_FOLDER)
