# SAMI - Sponge Automatic Model Identification

**Pipeline completo para identificação automática de espécimes de esponjas fósseis usando Vision Transformers**

---

## Índice
1. [Setup e Dependências](#1-setup)
2. [Vision Transformer](#2-vit)
3. [Utilitários Gerais](#3-utils)
4. [CBIR - Content-Based Image Retrieval](#4-cbir)
5. [Utilitários de Avaliação](#5-eval)
6. [Clustering Analysis](#6-clustering)
7. [Multi-Scale Patch Clustering](#7-patches)
8. [Pipeline de Avaliação](#8-pipeline)
9. [Exemplo de Uso](#9-example)

---
## 1. Setup e Dependências <a id="1-setup"></a>

In [None]:
# Instalar dependências (se necessário)
# !pip install torch torchvision numpy pandas matplotlib seaborn scikit-learn umap-learn opencv-python tqdm pillow

In [None]:
# Imports globais
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import cv2
import gc
import os
import shutil

from pathlib import Path
from PIL import Image
from functools import partial
from typing import List, Tuple, Dict, Optional
from tqdm import tqdm

import torchvision.transforms as transforms
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score,
    f1_score, precision_score, recall_score, silhouette_score,
    calinski_harabasz_score
)
from scipy.cluster.hierarchy import dendrogram, linkage

try:
    import umap
    UMAP_AVAILABLE = True
except ImportError:
    print("UMAP não disponível. Instale com: pip install umap-learn")
    UMAP_AVAILABLE = False

# Global plotting configuration
plt.style.use('default')
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
})

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Device: {DEVICE}")

---
## 2. Vision Transformer <a id="2-vit"></a>

Implementação do Vision Transformer baseada em DINO/timm, adaptada do benchmark SCAMPI.

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention with a single fused QKV projection (timm/DINO layout).

    Args:
        dim: Token embedding size C; split evenly into ``num_heads`` heads.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # Standard 1/sqrt(head_dim) scaling of the dot products.
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token axis of a (B, N, C) tensor; returns (B, N, C)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        q, k, v = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))


class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when falsy.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + Attn(norm1(x))`` followed by ``x + MLP(norm2(x))``; the MLP
    hidden size is ``int(dim * mlp_ratio)``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Run attention and MLP sub-layers, each with a residual connection."""
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.mlp(self.norm2(after_attn))


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one linearly.

    A convolution whose kernel size equals its stride (``patch_size``) is the
    same as flattening each patch and applying a shared linear projection.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map a (B, C, H, W) batch to (B, num_patches, embed_dim) tokens."""
        # The unpacking enforces a 4-D input (raises ValueError otherwise).
        _batch, _chans, _height, _width = x.shape
        patch_grid = self.proj(x)  # (B, E, H/p, W/p)
        return patch_grid.flatten(start_dim=2).transpose(1, 2)


class VisionTransformer(nn.Module):
    """Vision Transformer backbone for SAMI (DINO/timm-style).

    ``forward`` returns the final-layer-normalised CLS token, shape
    ``(B, embed_dim)``. Extra ``**kwargs`` are accepted and ignored.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0.,
                 attn_drop_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
            for _ in range(depth)])

        self.norm = norm_layer(embed_dim)
        # NOTE: forward() does not apply `head`; it returns the CLS feature directly.
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm reset to identity."""
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        """Resize the learned patch position embeddings to the current token grid.

        Args:
            x: Token tensor (CLS + patches) whose length determines the grid.
            w, h: Spatial sizes of the input image (dims 2 and 3 of the input,
                as unpacked in ``prepare_tokens``).

        Returns:
            Position embeddings of shape (1, 1 + num_patches, dim).
        """
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # F.interpolate derives the output size as floor(input * scale_factor); a small
        # offset keeps floating-point rounding from producing a grid one patch short
        # (see https://github.com/facebookresearch/dino/issues/8).
        w0, h0 = w0 + 0.1, h0 + 0.1
        grid = int(N ** 0.5)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / N ** 0.5, h0 / N ** 0.5),
            mode='bicubic',
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        """Patch-embed ``x``, prepend the CLS token and add position embeddings."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)
        return self.pos_drop(x)

    def forward(self, x):
        """Return the normalised CLS token, shape (B, embed_dim)."""
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]  # CLS token

    def get_last_selfattention(self, x):
        """Return the last block's attention weights, shape (B, num_heads, N, N).

        Calling ``blk.attn(...)`` would return the attention *output* (B, N, C),
        so the softmax weights are recomputed here from the block's own QKV
        projection, scale and attention dropout.

        Returns:
            The attention map, or None if the model has no blocks.
        """
        x = self.prepare_tokens(x)
        if len(self.blocks) == 0:
            return None
        for blk in self.blocks[:-1]:
            x = blk(x)

        last = self.blocks[-1]
        attn_layer = last.attn
        tokens = last.norm1(x)
        B, N, C = tokens.shape
        qkv = attn_layer.qkv(tokens).reshape(
            B, N, 3, attn_layer.num_heads, C // attn_layer.num_heads).permute(2, 0, 3, 1, 4)
        q, k = qkv[0], qkv[1]
        attn = (q @ k.transpose(-2, -1)) * attn_layer.scale
        attn = attn.softmax(dim=-1)
        return attn_layer.attn_drop(attn)

In [None]:
# Factory functions para diferentes tamanhos de ViT

def _build_vit(patch_size, embed_dim, depth, num_heads, **kwargs):
    """Shared constructor for the ViT presets (mlp_ratio=4, QKV bias, LayerNorm eps=1e-6)."""
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def vit_tiny(patch_size=16, **kwargs):
    """ViT-Tiny: 192-dim embeddings, 12 blocks, 3 heads."""
    return _build_vit(patch_size, 192, 12, 3, **kwargs)


def vit_small(patch_size=16, **kwargs):
    """ViT-Small: 384-dim embeddings, 12 blocks, 6 heads."""
    return _build_vit(patch_size, 384, 12, 6, **kwargs)


def vit_base(patch_size=16, **kwargs):
    """ViT-Base: 768-dim embeddings, 12 blocks, 12 heads."""
    return _build_vit(patch_size, 768, 12, 12, **kwargs)


def vit_large(patch_size=16, **kwargs):
    """ViT-Large: 1024-dim embeddings, 24 blocks, 16 heads."""
    return _build_vit(patch_size, 1024, 24, 16, **kwargs)

print("Vision Transformer definido com sucesso!")

---
## 3. Utilitários Gerais <a id="3-utils"></a>

Funções para carregamento de imagens, pré-processamento e manipulação de dados.

In [None]:
def load_image(image_path: str) -> Optional[Image.Image]:
    """Load an image from disk as an RGB PIL image.

    Args:
        image_path: Path to the image file.

    Returns:
        The RGB image, or None if it could not be opened or decoded (the
        error is printed). Callers such as ``load_dataset_images`` skip None.
    """
    try:
        # The context manager closes the file handle deterministically; convert()
        # returns a new, fully loaded image that stays valid after the close.
        with Image.open(image_path) as img:
            return img.convert('RGB')
    except Exception as e:  # deliberate best-effort: one bad file must not abort a dataset load
        print(f"Erro ao carregar {image_path}: {e}")
        return None


def get_transform(img_size: int = 224, is_training: bool = False):
    """
    Build the torchvision preprocessing pipeline.

    Both variants end with ToTensor + ImageNet mean/std normalisation.

    Args:
        img_size: Output side length in pixels.
        is_training: When True, use random crop/flip/rotation/colour-jitter
            augmentation. Otherwise resize the short side to ``int(img_size * 1.14)``
            and center-crop.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if not is_training:
        return transforms.Compose([
            transforms.Resize(int(img_size * 1.14)),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            normalize,
        ])

    augmentations = [
        transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    ]
    return transforms.Compose(augmentations + [transforms.ToTensor(), normalize])


def preprocess_image(img: Image.Image, img_size: int = 224) -> torch.Tensor:
    """Apply the deterministic (evaluation) transform to one PIL image.

    Returns a normalised (C, img_size, img_size) tensor without a batch dimension.
    """
    eval_transform = get_transform(img_size, is_training=False)
    tensor = eval_transform(img)
    return tensor


def load_dataset_images(data_path: str,
                        img_size: int = 224) -> Tuple[torch.Tensor, List[str], List[int]]:
    """
    Load every image of a class-per-subfolder dataset into one tensor.

    Each immediate subdirectory of ``data_path`` is a class. Class indices
    follow sorted folder-name order, the same order as ``get_class_names``.
    Files with a .jpg/.jpeg/.png extension (case-insensitive, so e.g. ".JPG"
    is included) are read in sorted filename order. This makes the returned
    ordering reproducible across runs and filesystems. Unreadable images are
    skipped.

    Args:
        data_path: Dataset root directory.
        img_size: Target image size passed to the evaluation transform.

    Returns:
        (images_tensor, image_paths, labels)

    Raises:
        ValueError: If no image could be loaded.
    """
    data_path = Path(data_path)
    extensions = {'.jpg', '.jpeg', '.png'}

    # Sort by folder *name* so the index order matches get_class_names on every
    # platform (Path ordering is case-insensitive on Windows, str ordering is not).
    class_folders = sorted((d for d in data_path.iterdir() if d.is_dir()),
                           key=lambda d: d.name)

    images = []
    labels = []
    image_paths = []

    transform = get_transform(img_size, is_training=False)

    print(f"Carregando imagens de {len(class_folders)} classes...")

    for class_idx, class_folder in enumerate(class_folders):
        image_files = sorted(
            (f for f in class_folder.iterdir()
             if f.is_file() and f.suffix.lower() in extensions),
            key=lambda f: f.name,
        )

        print(f"  {class_folder.name}: {len(image_files)} imagens")

        for img_file in image_files:
            img = load_image(str(img_file))
            if img is None:
                continue
            images.append(transform(img))
            labels.append(class_idx)
            image_paths.append(str(img_file))

    if not images:
        # torch.stack([]) would fail with an unhelpful message.
        raise ValueError(f"Nenhuma imagem válida encontrada em {data_path}")

    images_tensor = torch.stack(images)

    print(f"\nTotal de imagens: {len(images)}")
    print(f"Shape do tensor: {images_tensor.shape}")

    return images_tensor, image_paths, labels


def get_class_names(data_path: str) -> List[str]:
    """Return the names of the immediate subdirectories of ``data_path``, sorted as strings."""
    root = Path(data_path)
    names = [entry.name for entry in root.iterdir() if entry.is_dir()]
    names.sort()
    return names


def save_embeddings(embeddings: np.ndarray,
                    labels: List[int],
                    image_paths: List[str],
                    output_path: str):
    """Write embeddings, labels and image paths into one uncompressed .npz archive.

    Keys: 'embeddings', 'labels' (converted with np.array) and 'image_paths'.
    np.savez appends '.npz' to ``output_path`` when it lacks that suffix.
    """
    np.savez(
        output_path,
        embeddings=embeddings,
        labels=np.array(labels),
        image_paths=image_paths,
    )
    print(f"Embeddings salvos em {output_path}")


def load_embeddings(embeddings_path: str) -> Dict:
    """Load an archive written by ``save_embeddings``.

    Args:
        embeddings_path: Path to the .npz file.

    Returns:
        Dict with numpy arrays under 'embeddings', 'labels' and 'image_paths'.
    """
    # np.load on an .npz returns an NpzFile that keeps the file handle open until
    # closed; the `with` block closes it. Arrays read via [] are in-memory copies
    # and remain valid afterwards.
    # SECURITY NOTE(review): allow_pickle=True means a crafted .npz can execute
    # code when loaded -- only load archives you created yourself.
    with np.load(embeddings_path, allow_pickle=True) as data:
        return {
            'embeddings': data['embeddings'],
            'labels': data['labels'],
            'image_paths': data['image_paths']
        }


def create_splits(n_samples: int,
                  train_ratio: float = 0.7,
                  val_ratio: float = 0.15,
                  seed: int = 42) -> Tuple[List[int], List[int], List[int]]:
    """Randomly partition ``range(n_samples)`` into train/val/test index lists.

    The train and val sizes are ``int(n_samples * ratio)``; the test split gets
    the remainder. A private RandomState is used, so the permutation is the same
    one ``np.random.seed(seed)`` + ``np.random.permutation`` would give, but the
    caller's global NumPy RNG state is left untouched.

    Returns:
        (train_indices, val_indices, test_indices) as lists of Python ints.
    """
    rng = np.random.RandomState(seed)
    indices = rng.permutation(n_samples)

    n_train = int(n_samples * train_ratio)
    n_val = int(n_samples * val_ratio)

    train_indices = indices[:n_train]
    val_indices = indices[n_train:n_train + n_val]
    test_indices = indices[n_train + n_val:]

    return train_indices.tolist(), val_indices.tolist(), test_indices.tolist()


def count_parameters(model: torch.nn.Module) -> int:
    """Return the total element count of parameters with requires_grad=True."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


@torch.no_grad()
def extract_features(model: torch.nn.Module,
                     images: torch.Tensor,
                     batch_size: int = 32,
                     device: str = 'cuda') -> np.ndarray:
    """
    Run ``model`` over ``images`` in mini-batches and stack the outputs.

    The model is put in eval mode and moved to ``device`` (in place, as
    nn.Module.to does). Gradients are disabled for the whole call.

    Args:
        model: Feature extractor.
        images: Input tensor; batches are slices along the first dimension.
        batch_size: Number of images per forward pass.
        device: Device to run on.

    Returns:
        NumPy array of per-image outputs concatenated along axis 0.
    """
    model.eval()
    model = model.to(device)

    per_batch = [
        model(images[start:start + batch_size].to(device)).cpu().numpy()
        for start in range(0, len(images), batch_size)
    ]
    return np.vstack(per_batch)

print("Utilitários gerais definidos!")

---
## 4. CBIR - Content-Based Image Retrieval <a id="4-cbir"></a>

Funções para busca de imagens similares baseada em features visuais.

In [None]:
def build_index(embeddings: np.ndarray, metric: str = 'cosine', n_neighbors: int = 10):
    """
    Fit a nearest-neighbour index over the database embeddings for CBIR.

    For ``metric='cosine'`` the embeddings are L2-normalised and indexed with
    Euclidean distance. On unit vectors ||a - b||^2 = 2 - 2*cos(a, b), so the
    ranking matches cosine similarity. Query vectors are NOT normalised by
    ``retrieve_similar``/``batch_retrieve``; normalise them yourself when using
    cosine.

    Args:
        embeddings: Feature embeddings (N x D).
        metric: Distance metric ('cosine' or any scikit-learn metric name).
        n_neighbors: Default number of neighbours for queries.

    Returns:
        The fitted ``NearestNeighbors`` instance.
    """
    fit_data = embeddings
    fit_metric = metric
    if metric == 'cosine':
        fit_data = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        fit_metric = 'euclidean'

    index = NearestNeighbors(n_neighbors=n_neighbors, metric=fit_metric, algorithm='auto')
    return index.fit(fit_data)


def retrieve_similar(query_embedding: np.ndarray,
                     nn_model: NearestNeighbors,
                     k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
    """Return (distances, indices) of the k nearest database items for one query.

    A 1-D query is reshaped to a single-row matrix before searching.
    """
    query = query_embedding.reshape(1, -1) if query_embedding.ndim == 1 else query_embedding
    dists, idxs = nn_model.kneighbors(query, n_neighbors=k)
    return dists[0], idxs[0]


def batch_retrieve(query_embeddings: np.ndarray,
                   nn_model: NearestNeighbors,
                   k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
    """Return (distances, indices), each (Q, k), for a matrix of Q query embeddings."""
    return nn_model.kneighbors(query_embeddings, n_neighbors=k)


def compute_recall_at_k(query_labels: np.ndarray,
                        retrieved_indices: np.ndarray,
                        database_labels: np.ndarray,
                        k_values: List[int] = [1, 5, 10]) -> dict:
    """
    Compute Recall@k: the fraction of queries whose top-k results contain at
    least one item with the query's label.

    Args:
        query_labels: Label per query (array or list), length Q.
        retrieved_indices: (Q, >=k) database indices per query, ranked.
        database_labels: Label per database item (array or list).
        k_values: Cut-offs to evaluate (not modified).

    Returns:
        Dict mapping 'Recall@{k}' to a float; NaN when there are no queries.
    """
    query_labels = np.asarray(query_labels)
    database_labels = np.asarray(database_labels)
    retrieved_indices = np.asarray(retrieved_indices)

    results = {}
    for k in k_values:
        if len(query_labels) == 0:
            # Undefined for zero queries (the loop version divided by zero).
            results[f'Recall@{k}'] = float('nan')
            continue
        retrieved_labels = database_labels[retrieved_indices[:, :k]]  # (Q, k)
        hits = (retrieved_labels == query_labels[:, None]).any(axis=1)
        results[f'Recall@{k}'] = float(hits.mean())

    return results


def compute_precision_at_k(query_labels: np.ndarray,
                           retrieved_indices: np.ndarray,
                           database_labels: np.ndarray,
                           k_values: List[int] = [1, 5, 10]) -> dict:
    """
    Compute Precision@k: the mean, over queries, of (#top-k results sharing the
    query's label) / k.

    Args:
        query_labels: Label per query (array or list), length Q.
        retrieved_indices: (Q, >=k) database indices per query, ranked.
        database_labels: Label per database item (array or list).
        k_values: Cut-offs to evaluate (not modified).

    Returns:
        Dict mapping 'Precision@{k}' to a float; NaN when there are no queries.
    """
    query_labels = np.asarray(query_labels)
    database_labels = np.asarray(database_labels)
    retrieved_indices = np.asarray(retrieved_indices)

    results = {}
    for k in k_values:
        if len(query_labels) == 0:
            results[f'Precision@{k}'] = float('nan')
            continue
        retrieved_labels = database_labels[retrieved_indices[:, :k]]  # (Q, k)
        n_relevant = (retrieved_labels == query_labels[:, None]).sum(axis=1)
        results[f'Precision@{k}'] = float((n_relevant / k).mean())

    return results


def compute_map_at_k(query_labels: np.ndarray,
                     retrieved_indices: np.ndarray,
                     database_labels: np.ndarray,
                     k: int = 10) -> float:
    """
    Compute Mean Average Precision@k.

    A query's AP is the mean of precision@j over the ranks j <= k at which a
    relevant item (same label) appears. It is normalised by the number of
    relevant items retrieved, not by all relevant items in the database, and
    is 0 when none are retrieved. If fewer than k results are available per
    query, only those are used.

    Args:
        query_labels: Label per query (array or list), length Q.
        retrieved_indices: (Q, m) ranked database indices per query.
        database_labels: Label per database item (array or list).
        k: Cut-off rank.

    Returns:
        MAP as a float; NaN when there are no queries.
    """
    query_labels = np.asarray(query_labels)
    database_labels = np.asarray(database_labels)
    retrieved_indices = np.asarray(retrieved_indices)

    if len(query_labels) == 0:
        return float('nan')

    hits = database_labels[retrieved_indices[:, :k]] == query_labels[:, None]  # (Q, <=k)
    ranks = np.arange(1, hits.shape[1] + 1)
    precision_at_rank = np.cumsum(hits, axis=1) / ranks
    n_hits = hits.sum(axis=1)

    # Average precision only over ranks that are hits; queries with no hit get 0.
    ap = np.divide((precision_at_rank * hits).sum(axis=1), n_hits,
                   out=np.zeros(len(hits), dtype=float), where=n_hits > 0)
    return float(ap.mean())


def evaluate_cbir(embeddings: np.ndarray,
                  labels: np.ndarray,
                  k_values: List[int] = [1, 5, 10],
                  metric: str = 'cosine') -> dict:
    """
    Evaluate CBIR by leave-one-out: each sample queries all the other samples.

    Args:
        embeddings: Feature embeddings (N x D).
        labels: Class label per sample (array or list).
        k_values: Cut-offs for Recall@k / Precision@k.
        metric: 'cosine' (L2-normalise + Euclidean) or a scikit-learn metric.

    Returns:
        Dict with Recall@k and Precision@k for each k, plus MAP at
        ``min(10, max(k_values))`` (key 'MAP@10' for the defaults).
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    max_k = max(k_values)
    # Never ask for more ranks than were retrieved (the hard-coded k=10 crashed
    # with an IndexError whenever max(k_values) < 10).
    map_k = min(10, max_k)

    if metric == 'cosine':
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        metric = 'euclidean'

    nn_model = NearestNeighbors(n_neighbors=max_k, metric=metric)
    nn_model.fit(embeddings)

    # Without X, kneighbors() returns each fitted point's neighbours with the
    # point itself excluded. This is robust to duplicate embeddings, where the
    # query is not guaranteed to come back in column 0.
    indices = nn_model.kneighbors(return_distance=False)

    results = {}
    results.update(compute_recall_at_k(labels, indices, labels, k_values))
    results.update(compute_precision_at_k(labels, indices, labels, k_values))
    results[f'MAP@{map_k}'] = compute_map_at_k(labels, indices, labels, k=map_k)

    return results


def find_hard_negatives(embeddings: np.ndarray,
                        labels: np.ndarray,
                        k: int = 10) -> List[Tuple[int, int, float]]:
    """
    Find, for each sample, its closest neighbour with a different label.

    Similarity is cosine (Euclidean on L2-normalised vectors). Only the k
    nearest *other* samples are inspected; samples whose k neighbours all
    share their label produce no entry.

    Returns:
        List of (query_index, neighbour_index, distance) tuples.
    """
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    labels = np.asarray(labels)

    nn_model = NearestNeighbors(n_neighbors=k, metric='euclidean')
    nn_model.fit(embeddings)

    # kneighbors() without X excludes each point from its own neighbour list,
    # even when duplicate embeddings tie at distance 0.
    distances, indices = nn_model.kneighbors()

    hard_negatives = []
    for i, (row_idx, row_dist) in enumerate(zip(indices, distances)):
        different = np.flatnonzero(labels[row_idx] != labels[i])
        if different.size:
            j = different[0]  # nearest differently-labelled neighbour
            hard_negatives.append((i, row_idx[j], row_dist[j]))

    return hard_negatives

print("CBIR definido!")

---
## 5. Utilitários de Avaliação <a id="5-eval"></a>

Métricas de classificação, visualizações e relatórios.

In [None]:
def train_knn_classifier(train_embeddings: np.ndarray,
                         train_labels: np.ndarray,
                         n_neighbors: int = 5,
                         metric: str = 'cosine') -> KNeighborsClassifier:
    """Fit a k-NN classifier on the training embeddings and return it."""
    classifier = KNeighborsClassifier(n_neighbors=n_neighbors, metric=metric)
    return classifier.fit(train_embeddings, train_labels)


def evaluate_knn(knn: KNeighborsClassifier,
                 test_embeddings: np.ndarray,
                 test_labels: np.ndarray,
                 class_names: Optional[List[str]] = None) -> Dict:
    """
    Evaluate a fitted k-NN classifier on a test split.

    Args:
        knn: Fitted classifier.
        test_embeddings: Test features.
        test_labels: True labels. When ``class_names`` is given, these are
            assumed to be integer indices into it (as produced by
            ``load_dataset_images``).
        class_names: Optional display names for the per-class report.

    Returns:
        Dict with accuracy, macro/weighted F1, macro precision/recall,
        'confusion_matrix', 'predictions' and, if ``class_names`` is given,
        'classification_report' (dict form).
    """
    predictions = knn.predict(test_embeddings)

    metrics = {
        'accuracy': accuracy_score(test_labels, predictions),
        'macro_f1': f1_score(test_labels, predictions, average='macro', zero_division=0),
        'weighted_f1': f1_score(test_labels, predictions, average='weighted', zero_division=0),
        'macro_precision': precision_score(test_labels, predictions, average='macro', zero_division=0),
        'macro_recall': recall_score(test_labels, predictions, average='macro', zero_division=0),
    }

    # Pin the label set to 0..len(class_names)-1 so the report and the confusion
    # matrix stay aligned with class_names even when some classes are missing
    # from this split (otherwise classification_report raises a size mismatch).
    label_ids = None
    if class_names is not None:
        label_ids = np.arange(len(class_names))
        metrics['classification_report'] = classification_report(
            test_labels, predictions,
            labels=label_ids,
            target_names=class_names,
            zero_division=0,
            output_dict=True)

    metrics['confusion_matrix'] = confusion_matrix(test_labels, predictions, labels=label_ids)
    metrics['predictions'] = predictions

    return metrics


def plot_confusion_matrix(cm: np.ndarray,
                         class_names: List[str],
                         title: str = 'Confusion Matrix',
                         save_path: str = None,
                         figsize: tuple = (12, 10)):
    """Plota matriz de confusão"""
    plt.figure(figsize=figsize)
    
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Normalized Count'})
    
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Matriz de confusão salva em {save_path}")
    
    plt.show()
    plt.close()


def plot_tsne(embeddings: np.ndarray,
              labels: np.ndarray,
              class_names: List[str],
              title: str = 't-SNE Visualization',
              save_path: str = None,
              figsize: tuple = (12, 10),
              perplexity: int = 30):
    """Project embeddings to 2-D with t-SNE and scatter-plot them coloured by class.

    ``labels`` (array or list) must hold integer indices into ``class_names``.
    The perplexity is reduced to ``n_samples - 1`` for small datasets, because
    scikit-learn's TSNE requires perplexity < n_samples.
    """
    print("Computando t-SNE...")

    # Convert so `labels == class_idx` is an element-wise mask; on a plain list it
    # would be a single False and nothing would be plotted.
    labels = np.asarray(labels)
    n_samples = len(embeddings)
    effective_perplexity = min(perplexity, max(n_samples - 1, 1))
    if effective_perplexity != perplexity:
        print(f"  perplexity reduzida para {effective_perplexity} (n_amostras={n_samples})")

    tsne = TSNE(n_components=2, random_state=42, perplexity=effective_perplexity)
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=figsize)

    for class_idx, class_name in enumerate(class_names):
        mask = labels == class_idx
        plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                    label=class_name, alpha=0.6, s=50)

    plt.title(title)
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"t-SNE salvo em {save_path}")

    plt.show()
    plt.close()


def compute_class_wise_metrics(true_labels: np.ndarray,
                               predictions: np.ndarray,
                               class_names: List[str]) -> pd.DataFrame:
    """
    Compute one-vs-rest precision, recall, F1 and support for every class.

    Args:
        true_labels: True class indices (array or list).
        predictions: Predicted class indices (array or list).
        class_names: Names; class i is ``class_names[i]``.

    Returns:
        DataFrame with columns Class, Precision, Recall, F1-Score, Support.
        Undefined ratios (zero denominators) are reported as 0.
    """
    # Convert so comparisons below are element-wise; on plain lists `== class_idx`
    # would collapse to a single False and every metric would silently be 0.
    true_labels = np.asarray(true_labels)
    predictions = np.asarray(predictions)

    rows = []
    for class_idx, class_name in enumerate(class_names):
        is_true = true_labels == class_idx
        is_pred = predictions == class_idx

        tp = int(np.sum(is_true & is_pred))
        fp = int(np.sum(~is_true & is_pred))
        fn = int(np.sum(is_true & ~is_pred))

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        rows.append({
            'Class': class_name,
            'Precision': precision,
            'Recall': recall,
            'F1-Score': f1,
            'Support': int(np.sum(is_true)),
        })

    return pd.DataFrame(rows)


def save_evaluation_report(metrics: Dict,
                           class_names: List[str],
                           save_path: str):
    """Write a plain-text evaluation report.

    The report holds the overall metrics from ``evaluate_knn`` plus per-class
    entries for every name in ``class_names`` present in
    ``metrics['classification_report']``.

    The file is written as UTF-8 explicitly, so class names with accents or
    other non-ASCII characters do not depend on the platform's default encoding.
    """
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("=" * 80 + "\n")
        f.write("SAMI - Evaluation Report\n")
        f.write("=" * 80 + "\n\n")

        f.write("Overall Metrics:\n")
        f.write("-" * 80 + "\n")
        f.write(f"Accuracy: {metrics['accuracy']:.4f}\n")
        f.write(f"Macro F1-Score: {metrics['macro_f1']:.4f}\n")
        f.write(f"Weighted F1-Score: {metrics['weighted_f1']:.4f}\n")
        f.write(f"Macro Precision: {metrics['macro_precision']:.4f}\n")
        f.write(f"Macro Recall: {metrics['macro_recall']:.4f}\n\n")

        if 'classification_report' in metrics:
            f.write("Per-Class Metrics:\n")
            f.write("-" * 80 + "\n")
            report = metrics['classification_report']

            for class_name in class_names:
                if class_name not in report:
                    continue
                class_metrics = report[class_name]
                f.write(f"\n{class_name}:\n")
                f.write(f"  Precision: {class_metrics['precision']:.4f}\n")
                f.write(f"  Recall: {class_metrics['recall']:.4f}\n")
                f.write(f"  F1-Score: {class_metrics['f1-score']:.4f}\n")
                f.write(f"  Support: {class_metrics['support']}\n")

        f.write("\n" + "=" * 80 + "\n")

    print(f"Relatório salvo em {save_path}")


def compare_k_values(embeddings: np.ndarray,
                     labels: np.ndarray,
                     k_values: List[int] = [1, 3, 5, 7, 9, 11],
                     metric: str = 'cosine') -> pd.DataFrame:
    """
    Compare k-NN classification quality for several k using leave-one-out.

    Each sample is classified by a uniform majority vote of its k nearest
    *other* samples. Ties go to the smallest class label. Fitting and
    predicting on the same data would let every sample vote for itself, so
    k=1 would trivially score 100%.

    Args:
        embeddings: Feature matrix (N x D).
        labels: Class label per sample (array or list).
        k_values: Neighbour counts to compare (each must be < N).
        metric: Distance metric passed to scikit-learn.

    Returns:
        DataFrame with columns k, accuracy, macro_f1 (one row per k).
    """
    results = []
    if len(k_values) == 0:
        return pd.DataFrame(results)

    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    classes, encoded = np.unique(labels, return_inverse=True)
    encoded = encoded.ravel()

    # One neighbour search at the largest k; smaller k reuse its prefix.
    nn_model = NearestNeighbors(n_neighbors=max(k_values), metric=metric)
    nn_model.fit(embeddings)
    # Without X, kneighbors() excludes each point from its own neighbour list.
    neighbor_idx = nn_model.kneighbors(return_distance=False)
    neighbor_classes = encoded[neighbor_idx]  # (N, max_k) encoded class ids
    one_hot = np.eye(len(classes), dtype=int)

    for k in k_values:
        votes = one_hot[neighbor_classes[:, :k]].sum(axis=1)  # (N, n_classes)
        predictions = classes[votes.argmax(axis=1)]  # argmax -> first (smallest) class on ties

        acc = accuracy_score(labels, predictions)
        f1 = f1_score(labels, predictions, average='macro', zero_division=0)

        results.append({'k': k, 'accuracy': acc, 'macro_f1': f1})

    return pd.DataFrame(results)

print("Utilitários de avaliação definidos!")

---
## 6. Clustering Analysis <a id="6-clustering"></a>

Descoberta automática de grupos de esponjas sem labels prévios.

In [None]:
def compute_tsne_clustering(embeddings, perplexity=30, random_state=42):
    """Embed ``embeddings`` in 2-D with t-SNE for cluster plots; returns the (N, 2) coordinates."""
    print(f"Computando t-SNE (perplexity={perplexity})...")
    reducer = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
    return reducer.fit_transform(embeddings)


def plot_tsne_clusters(coords_2d, labels, title, save_path=None, n_clusters=None):
    """Scatter-plot 2-D coordinates coloured by cluster label.

    Args:
        coords_2d: (N, 2) coordinates.
        labels: Cluster label per point (array or list).
        title: Figure title.
        save_path: Optional file path for the PNG.
        n_clusters: If truthy, labels are assumed to be 0..n_clusters-1.
            Otherwise the distinct labels are plotted and -1 is drawn in black
            as 'Noise' (DBSCAN convention).
    """
    # Convert so `labels == i` is an element-wise mask even for list input.
    labels = np.asarray(labels)
    plt.figure(figsize=(12, 10))

    if n_clusters:
        colors = plt.cm.rainbow(np.linspace(0, 1, n_clusters))
        for i in range(n_clusters):
            mask = labels == i
            plt.scatter(coords_2d[mask, 0], coords_2d[mask, 1],
                        c=[colors[i]], label=f'Cluster {i}',
                        alpha=0.6, s=50, edgecolors='black', linewidth=0.5)
    else:
        # Sorted so colour assignment and legend order are deterministic
        # (set iteration order is not).
        unique_labels = sorted(set(labels.tolist()))
        colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))

        for i, color in zip(unique_labels, colors):
            if i == -1:
                color = 'black'
                label = 'Noise'
            else:
                label = f'Cluster {i}'

            mask = labels == i
            plt.scatter(coords_2d[mask, 0], coords_2d[mask, 1],
                        c=[color], label=label, alpha=0.6, s=50,
                        edgecolors='black', linewidth=0.5)

    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title(title, fontsize=14, fontweight='bold')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot salvo em {save_path}")

    plt.show()
    plt.close()


def kmeans_clustering(embeddings, n_clusters_range, coords_2d=None, output_dir=None):
    """
    Run K-Means for each k in ``n_clusters_range`` and score every partition.

    Args:
        embeddings: Feature matrix (N x D).
        n_clusters_range: Iterable of k values. Values outside ``2 <= k < N``
            are skipped with a message, because the silhouette score used for
            model selection is undefined there.
        coords_2d: Optional 2-D coordinates; when given, each partition is plotted.
        output_dir: Optional directory (str or Path) for the plots.

    Returns:
        (results, df): ``results`` is a list of dicts with metrics and the
        cluster ``labels`` per evaluated k; ``df`` holds the metrics only.
    """
    print("\n" + "="*80)
    print("K-MEANS CLUSTERING")
    print("="*80)

    # Accept str as well as Path; the `/` operator below needs a Path.
    out_dir = Path(output_dir) if output_dir else None
    n_samples = len(embeddings)
    results = []

    for n_clusters in n_clusters_range:
        if not 2 <= n_clusters < n_samples:
            print(f"\nIgnorando k={n_clusters}: é necessário 2 <= k < {n_samples}")
            continue

        print(f"\nTestando k={n_clusters}...")

        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = kmeans.fit_predict(embeddings)

        silhouette = silhouette_score(embeddings, labels)
        calinski = calinski_harabasz_score(embeddings, labels)
        inertia = kmeans.inertia_

        print(f"  Silhouette Score: {silhouette:.4f}")
        print(f"  Calinski-Harabasz: {calinski:.2f}")
        print(f"  Inertia: {inertia:.2f}")

        unique, counts = np.unique(labels, return_counts=True)
        print(f"  Amostras por cluster: {dict(zip(unique, counts))}")

        results.append({
            'n_clusters': n_clusters,
            'silhouette': silhouette,
            'calinski_harabasz': calinski,
            'inertia': inertia,
            'labels': labels
        })

        if coords_2d is not None:
            save_path = str(out_dir / f'kmeans_k{n_clusters}.png') if out_dir else None
            plot_tsne_clusters(
                coords_2d, labels,
                f'K-Means (k={n_clusters}) | Silhouette: {silhouette:.3f}',
                save_path=save_path,
                n_clusters=n_clusters
            )

    metric_columns = ['n_clusters', 'silhouette', 'calinski_harabasz', 'inertia']
    df = pd.DataFrame([{col: r[col] for col in metric_columns} for r in results],
                      columns=metric_columns)

    if df.empty:
        print("\nNenhum valor de k válido foi avaliado.")
    else:
        best_k = df.loc[df['silhouette'].idxmax(), 'n_clusters']
        print(f"\nMelhor k por Silhouette Score: {int(best_k)}")

    return results, df


def dbscan_clustering(embeddings, coords_2d=None, output_dir=None, eps_values=None, min_samples=5):
    """Run DBSCAN (density-based clustering) for several ``eps`` radii.

    For each eps this prints the number of clusters, the noise count, the
    silhouette score of the non-noise points (when defined) and the cluster
    sizes. It optionally plots each result on a 2-D projection.

    Args:
        embeddings: Feature matrix (one row per sample).
        coords_2d: Optional 2-D coordinates, used only for plotting.
        output_dir: Optional directory (``str`` or ``Path``) for the plots.
        eps_values: Radii to try. Defaults to ``[0.3, 0.5, 0.7, 1.0]``.
        min_samples: DBSCAN ``min_samples`` (default 5).

    Returns:
        List of dicts, one per eps, with keys ``eps``, ``n_clusters``,
        ``n_noise``, ``silhouette`` (``None`` when not computable) and
        ``labels`` (-1 marks noise).
    """
    print("\n" + "="*80)
    print("DBSCAN CLUSTERING")
    print("="*80)

    # Accept plain strings as well as Path objects for output_dir.
    output_dir = Path(output_dir) if output_dir else None
    embeddings = np.asarray(embeddings)
    if eps_values is None:
        eps_values = [0.3, 0.5, 0.7, 1.0]

    results = []

    for eps in eps_values:
        print(f"\nTestando eps={eps}...")

        dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='euclidean')
        labels = dbscan.fit_predict(embeddings)

        noise_mask = labels == -1
        n_noise = int(noise_mask.sum())
        n_clusters = len(np.unique(labels[~noise_mask]))

        print(f"  Clusters encontrados: {n_clusters}")
        print(f"  Pontos de ruído: {n_noise}")

        silhouette = None
        if n_clusters > 1:
            valid_mask = ~noise_mask
            # silhouette_score needs n_labels <= n_samples - 1 on the subset.
            if valid_mask.sum() > n_clusters:
                silhouette = silhouette_score(embeddings[valid_mask], labels[valid_mask])
                print(f"  Silhouette Score: {silhouette:.4f}")

        unique, counts = np.unique(labels, return_counts=True)
        print(f"  Amostras por cluster: {dict(zip(unique, counts))}")

        results.append({
            'eps': eps,
            'n_clusters': n_clusters,
            'n_noise': n_noise,
            'silhouette': silhouette,
            'labels': labels,
        })

        if coords_2d is not None:
            save_path = str(output_dir / f'dbscan_eps{eps}.png') if output_dir else None
            plot_tsne_clusters(
                coords_2d, labels,
                f'DBSCAN (eps={eps}) | Clusters: {n_clusters}, Noise: {n_noise}',
                save_path=save_path
            )

    return results


def hierarchical_clustering(embeddings, coords_2d=None, output_dir=None, n_clusters_list=(3, 5, 7, 10)):
    """Run Ward agglomerative clustering, draw its dendrogram and score cuts.

    The dendrogram is built from a SciPy Ward linkage. For each requested
    number of clusters, sklearn's ``AgglomerativeClustering`` (Ward) is fitted
    and its silhouette score and cluster sizes are printed.

    Args:
        embeddings: Feature matrix (one row per sample).
        coords_2d: Optional 2-D coordinates, used only for plotting.
        output_dir: Optional directory (``str`` or ``Path``) for the figures.
        n_clusters_list: Cluster counts to evaluate (default 3, 5, 7, 10).
            Counts outside ``2 <= n < n_samples`` are skipped, because the
            silhouette score is undefined for them.

    Returns:
        List of dicts with keys ``n_clusters``, ``silhouette`` and ``labels``.
    """
    print("\n" + "="*80)
    print("HIERARCHICAL CLUSTERING")
    print("="*80)

    # Accept plain strings as well as Path objects for output_dir.
    output_dir = Path(output_dir) if output_dir else None
    n_samples = len(embeddings)

    print("Computando matriz de linkage...")
    linkage_matrix = linkage(embeddings, method='ward')

    plt.figure(figsize=(15, 8))
    dendrogram(linkage_matrix, no_labels=True)
    plt.title('Dendrogram - Hierarchical Clustering', fontsize=14, fontweight='bold')
    plt.xlabel('Sample Index')
    plt.ylabel('Distance')
    plt.tight_layout()

    if output_dir:
        dendrogram_path = output_dir / 'hierarchical_dendrogram.png'
        plt.savefig(dendrogram_path, dpi=300, bbox_inches='tight')
        print(f"Dendrograma salvo em {dendrogram_path}")

    plt.show()
    plt.close()

    results = []

    for n in n_clusters_list:
        if not 2 <= n < n_samples:
            print(f"\nn_clusters={n} inválido para {n_samples} amostras; ignorando.")
            continue

        clustering = AgglomerativeClustering(n_clusters=n, linkage='ward')
        labels = clustering.fit_predict(embeddings)

        silhouette = silhouette_score(embeddings, labels)
        print(f"\nn_clusters={n}: Silhouette={silhouette:.4f}")

        unique, counts = np.unique(labels, return_counts=True)
        print(f"  Amostras por cluster: {dict(zip(unique, counts))}")

        results.append({'n_clusters': n, 'silhouette': silhouette, 'labels': labels})

        if coords_2d is not None:
            save_path = str(output_dir / f'hierarchical_n{n}.png') if output_dir else None
            plot_tsne_clusters(
                coords_2d, labels,
                f'Hierarchical (n={n}) | Silhouette: {silhouette:.3f}',
                save_path=save_path,
                n_clusters=n
            )

    return results


def save_cluster_examples(image_paths, labels, output_dir, max_per_cluster=20):
    """Copy up to ``max_per_cluster`` images of each cluster into per-cluster folders.

    Files are written to ``<output_dir>/cluster_images/cluster_<id>/``, or to
    ``noise/`` for label -1. Each copy is prefixed with a zero-padded index,
    e.g. ``000_specimen.jpg``. Images are taken in their original order.

    Args:
        image_paths: Sequence of image file paths, aligned index-by-index with
            ``labels``.
        labels: Cluster label per image (list or NumPy array); -1 means noise.
        output_dir: Base output directory; created (with parents) if missing.
        max_per_cluster: Maximum number of images copied per cluster.
    """
    print("\n" + "="*80)
    print("SALVANDO EXEMPLOS DOS CLUSTERS")
    print("="*80)

    # Convert so ``labels == cluster_id`` is an element-wise mask even when a
    # plain list is passed (a list comparison would yield a single bool).
    labels = np.asarray(labels)

    clusters_dir = Path(output_dir) / 'cluster_images'
    clusters_dir.mkdir(parents=True, exist_ok=True)

    for cluster_id in np.unique(labels):
        cluster_name = 'noise' if cluster_id == -1 else f'cluster_{cluster_id}'

        cluster_dir = clusters_dir / cluster_name
        cluster_dir.mkdir(exist_ok=True)

        cluster_paths = [image_paths[i] for i in np.flatnonzero(labels == cluster_id)]

        for i, img_path in enumerate(cluster_paths[:max_per_cluster]):
            src = Path(img_path)
            shutil.copy(src, cluster_dir / f'{i:03d}_{src.name}')

        print(f"  {cluster_name}: {len(cluster_paths)} imagens ({min(len(cluster_paths), max_per_cluster)} salvas)")

    print(f"\nExemplos salvos em {clusters_dir}")

print("Clustering Analysis definido!")

---
## 7. Multi-Scale Patch Clustering <a id="7-patches"></a>

Análise de patches multiescala com UMAP, normalização por percentis (5%–95%) e remoção de fundo (background).

In [None]:
def percentile_normalize_image(img, lower_percentile=5, upper_percentile=95):
    """Contrast-stretch an image channel by channel using percentile clipping.

    Each channel is clipped to its own
    ``[P(lower_percentile), P(upper_percentile)]`` range and linearly rescaled
    to ``[0, 255]``. Channels are stretched independently, which preserves
    per-channel (RGB) patterns. A flat channel, where both percentiles are
    equal, is left at its original value.

    Args:
        img: ``H x W x C`` array with any number of channels, or an ``H x W``
            grayscale array.
        lower_percentile: Lower clipping percentile in ``[0, 100]``.
        upper_percentile: Upper clipping percentile in ``[0, 100]``.

    Returns:
        ``uint8`` array with the same shape as ``img``.
    """
    img = np.asarray(img)
    is_gray = img.ndim == 2
    # Treat grayscale as a single-channel stack so one code path handles both.
    stack = img[:, :, np.newaxis] if is_gray else img

    out = np.zeros(stack.shape, dtype=np.float32)

    for c in range(stack.shape[2]):
        channel = stack[:, :, c].astype(np.float32)

        p_low = np.percentile(channel, lower_percentile)
        p_high = np.percentile(channel, upper_percentile)

        clipped = np.clip(channel, p_low, p_high)

        if p_high > p_low:
            out[:, :, c] = (clipped - p_low) / (p_high - p_low) * 255.0
        else:
            # Nothing to stretch: keep the (constant) original value.
            out[:, :, c] = clipped

    # A flat channel of a non-uint8 input may exceed 255; clip so the cast
    # below cannot wrap around. In-range values are unaffected.
    np.clip(out, 0.0, 255.0, out=out)
    result = out.astype(np.uint8)
    return result[:, :, 0] if is_gray else result


def is_valid_patch(patch, min_content_ratio=0.7, bright_threshold=240, dark_threshold=15):
    """Return True when the patch is not mostly background.

    The patch is converted RGB -> grayscale. Pixels brighter than
    ``bright_threshold`` or darker than ``dark_threshold`` count as
    background. The patch is accepted when the remaining fraction of pixels
    is at least ``min_content_ratio``.

    Args:
        patch: RGB image array (as produced by ``extract_patches_from_image``).
        min_content_ratio: Minimum fraction of non-background pixels.
        bright_threshold: Gray level above which a pixel is background (default 240).
        dark_threshold: Gray level below which a pixel is background (default 15).

    Returns:
        ``bool``. An empty patch is never valid.
    """
    # An empty patch would otherwise divide by zero when computing the ratio.
    if patch is None or patch.size == 0:
        return False

    gray = cv2.cvtColor(patch, cv2.COLOR_RGB2GRAY)
    background_mask = (gray > bright_threshold) | (gray < dark_threshold)
    content_ratio = 1.0 - float(background_mask.mean())
    return content_ratio >= min_content_ratio


def extract_patches_from_image(image_path, window_sizes, stride, max_patches=50, min_content_ratio=0.7):
    """Extract square, content-rich patches from one image at several scales.

    Pre-processing: the image is read with OpenCV, converted BGR -> RGB and
    contrast-stretched per channel with ``percentile_normalize_image``
    (5%-95%). For each window size, candidate top-left corners on a
    ``stride`` grid are visited in a shuffled order (seed 42). Patches
    rejected by ``is_valid_patch`` are skipped until
    ``max_patches // len(window_sizes)`` patches of that size are collected.

    Args:
        image_path: Path (``str`` or ``Path``) of the image file.
        window_sizes: Sequence of square patch side lengths in pixels.
        stride: Step in pixels between candidate patch positions.
        max_patches: Total patch budget, split evenly across window sizes
            (integer division, so any remainder is unused).
        min_content_ratio: Minimum fraction of non-background pixels.

    Returns:
        List of ``(patch, metadata)`` tuples. ``patch`` is the uint8 RGB crop;
        ``metadata`` has ``image_path``, ``image_name``, ``window_size``,
        ``x``, ``y`` and ``coords``. Returns ``[]`` if the image cannot be
        read or ``window_sizes`` is empty.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return []

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_normalized = percentile_normalize_image(img_rgb, lower_percentile=5, upper_percentile=95)

    h, w = img_normalized.shape[:2]
    image_path_str = str(image_path)
    image_name = Path(image_path).stem
    patches = []

    window_sizes = list(window_sizes)
    if not window_sizes:
        return patches
    max_per_size = max_patches // len(window_sizes)

    for window_size in window_sizes:
        positions = [
            (y, x)
            for y in range(0, h - window_size + 1, stride)
            for x in range(0, w - window_size + 1, stride)
        ]
        # A fresh RandomState(42) gives exactly the same permutation as
        # ``np.random.seed(42); np.random.shuffle(...)``, but it does not
        # reset NumPy's global RNG for every other caller in the process.
        np.random.RandomState(42).shuffle(positions)

        n_valid = 0
        for y, x in positions:
            if n_valid >= max_per_size:
                break

            # Copy so the kept patch does not hold a view on the whole image.
            patch = img_normalized[y:y + window_size, x:x + window_size].copy()

            if not is_valid_patch(patch, min_content_ratio):
                continue

            metadata = {
                'image_path': image_path_str,
                'image_name': image_name,
                'window_size': window_size,
                'x': x,
                'y': y,
                'coords': f'{x}_{y}'
            }

            patches.append((patch, metadata))
            n_valid += 1

    # No explicit ``del`` / ``gc.collect()``: the full-size arrays are freed
    # by reference counting when this function returns, so a full GC pass per
    # image only adds overhead.
    return patches


def extract_all_patches(image_paths, class_names, window_sizes, stride, max_patches_per_image, min_content_ratio):
    """Extract patches from every image and tag each one with its class.

    Each image goes through ``extract_patches_from_image``. Each patch's
    metadata is then extended with ``class`` and a ``prefix`` of the form
    ``<class>/<image_name>/window_<size>``.

    Args:
        image_paths: Sequence of image paths.
        class_names: Class name per image, aligned with ``image_paths``.
        window_sizes, stride, max_patches_per_image, min_content_ratio:
            Forwarded to ``extract_patches_from_image``.

    Returns:
        List of ``(patch, metadata)`` tuples for all images.

    Raises:
        ValueError: if ``image_paths`` and ``class_names`` differ in length
            (``zip`` would otherwise silently drop the extra items).
    """
    from collections import Counter

    if len(image_paths) != len(class_names):
        raise ValueError(
            f"image_paths ({len(image_paths)}) e class_names "
            f"({len(class_names)}) devem ter o mesmo tamanho"
        )

    all_patches = []

    print("\nExtraindo patches com:")
    print("  - Normalização percentil (5%-95%)")
    print(f"  - Remoção de background (min content: {min_content_ratio*100}%)")
    print("  - Preservação de padrões RGB")
    print()

    pairs = zip(image_paths, class_names)
    for n_done, (img_path, class_name) in enumerate(tqdm(pairs, total=len(image_paths)), start=1):
        patches = extract_patches_from_image(
            img_path, window_sizes, stride, max_patches_per_image, min_content_ratio
        )

        for patch, metadata in patches:
            metadata['class'] = class_name
            metadata['prefix'] = f"{class_name}/{metadata['image_name']}/window_{metadata['window_size']}"
            all_patches.append((patch, metadata))

        # Periodic collection every 100 images. The previous check on the
        # running patch count fired irregularly (including on every image
        # while the total was still 0).
        if n_done % 100 == 0:
            gc.collect()

    print(f"Total de patches extraídos: {len(all_patches)}")

    # Single pass over the patches instead of one full scan per class.
    per_class = Counter(m['class'] for _, m in all_patches)
    # dict.fromkeys keeps first-appearance order, so the report is deterministic.
    for class_name in dict.fromkeys(class_names):
        print(f"  {class_name}: {per_class[class_name]} patches")

    return all_patches


def extract_features_from_patches(patches, model, device, batch_size=32):
    """Embed every patch with ``model`` in mini-batches.

    Args:
        patches: Non-empty sequence of ``(patch_img, metadata)`` tuples.
            ``patch_img`` goes through ``PIL.Image.fromarray`` and then
            ``preprocess_image(..., img_size=224)``. It is presumably a uint8
            RGB array, as produced by ``extract_patches_from_image``.
        model: Torch module applied to each stacked batch (put in eval mode).
        device: Target device, e.g. ``'cpu'``, ``'cuda'``, ``'cuda:0'`` or a
            ``torch.device``.
        batch_size: Number of patches per forward pass.

    Returns:
        NumPy array obtained by ``np.vstack`` of the per-batch model outputs.
        This assumes the model returns ``(batch, dim)`` features, giving one
        row per patch; confirm for other model heads.

    Raises:
        ValueError: if ``patches`` is empty.
    """
    if len(patches) == 0:
        raise ValueError("Nenhum patch fornecido para extração de features.")

    print("\nExtraindo features dos patches...")

    model.eval()
    # Compare as a string so 'cuda:0' and torch.device('cuda') also release
    # cached GPU memory (the old ``device == 'cuda'`` check missed them).
    on_cuda = str(device).startswith('cuda')
    all_features = []

    for start in tqdm(range(0, len(patches), batch_size)):
        chunk = patches[start:start + batch_size]

        batch = torch.stack([
            preprocess_image(Image.fromarray(patch_img), img_size=224)
            for patch_img, _ in chunk
        ]).to(device)

        with torch.no_grad():
            features = model(batch)

        all_features.append(features.cpu().numpy())

        del batch, features
        if on_cuda:
            torch.cuda.empty_cache()

        if start % (batch_size * 10) == 0:
            gc.collect()

    features_array = np.vstack(all_features)
    print(f"Shape das features: {features_array.shape}")

    return features_array


def cluster_patches(features, n_clusters, metadata_list, silhouette_sample_size=None):
    """Cluster patch features with K-Means on L2-normalised vectors.

    Args:
        features: Feature matrix (one row per patch).
        n_clusters: Number of K-Means clusters.
        metadata_list: Per-patch metadata dicts (aligned with ``features``);
            their ``'class'`` entry is used for the per-cluster report.
        silhouette_sample_size: If set, the silhouette score is estimated on a
            random subset of this size (seed 42). ``None`` (default) computes
            the exact score, whose cost grows quadratically with the number
            of patches.

    Returns:
        ``(labels, kmeans, silhouette)``: the cluster label per patch, the
        fitted ``KMeans`` model and the silhouette score.
    """
    from collections import Counter

    print(f"\nClusterizando patches em {n_clusters} grupos...")

    # Unit-normalise rows so Euclidean K-Means reflects cosine similarity;
    # the epsilon avoids division by zero for all-zero feature rows.
    features_norm = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-8)

    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10, max_iter=300)
    labels = kmeans.fit_predict(features_norm)

    # random_state only matters when sample_size is set.
    silhouette = silhouette_score(
        features_norm, labels,
        sample_size=silhouette_sample_size, random_state=42
    )
    print(f"Silhouette Score: {silhouette:.4f}")

    print("\nEstatísticas dos Clusters:")
    # One pass over all patches instead of rescanning them once per cluster.
    sizes = np.bincount(labels, minlength=n_clusters)
    class_counts = [Counter() for _ in range(n_clusters)]
    for label, metadata in zip(labels, metadata_list):
        class_counts[label][metadata['class']] += 1

    for cluster_id in range(n_clusters):
        print(f"  Cluster {cluster_id}: {sizes[cluster_id]} patches - {dict(class_counts[cluster_id])}")

    return labels, kmeans, silhouette


def visualize_umap_patches(features, labels, metadata_list, save_path=None, silhouette_val=None, max_points=5000):
    """Plot a 2-D UMAP projection of patch features, coloured by cluster and by class.

    The left panel colours points by cluster label. The right panel colours
    them by the ``'class'`` entry of each metadata dict. When there are more
    than ``max_points`` patches, a reproducible random subset (seed 42) is
    projected. Does nothing when UMAP is not installed.

    Args:
        features: Feature matrix (one row per patch).
        labels: Cluster label per patch (any integer ids, e.g. -1 for noise).
        metadata_list: Per-patch metadata dicts containing ``'class'``.
        save_path: Optional file path for the figure.
        silhouette_val: Optional silhouette score shown in the title
            (shown for any value, including 0.0).
        max_points: Maximum number of patches to project (default 5000).
    """
    if not UMAP_AVAILABLE:
        print("UMAP não disponível. Pulando visualização UMAP.")
        return

    print("\nComputando projeção UMAP...")

    features = np.asarray(features)
    labels = np.asarray(labels)

    if len(features) > max_points:
        print(f"  Amostrando {max_points} patches de {len(features)} para visualização...")
        # Local seeded RNG: reproducible subsample without touching the
        # global NumPy RNG.
        indices = np.random.RandomState(42).choice(len(features), max_points, replace=False)
        features_sample = features[indices]
        labels_sample = labels[indices]
        metadata_sample = [metadata_list[i] for i in indices]
    else:
        features_sample = features
        labels_sample = labels
        metadata_sample = metadata_list

    reducer = umap.UMAP(
        n_neighbors=15,
        min_dist=0.1,
        n_components=2,
        metric='cosine',
        random_state=42
    )
    coords_2d = reducer.fit_transform(features_sample)

    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    # By cluster: iterate over the ids actually present. After subsampling
    # they may be non-contiguous, so ``range(n_clusters)`` would skip some.
    ax1 = axes[0]
    cluster_ids = np.unique(labels_sample)
    cluster_colors = plt.cm.rainbow(np.linspace(0, 1, len(cluster_ids)))

    for color, cluster_id in zip(cluster_colors, cluster_ids):
        mask = labels_sample == cluster_id
        ax1.scatter(coords_2d[mask, 0], coords_2d[mask, 1],
                    c=[color], label=f'Cluster {cluster_id}',
                    alpha=0.6, s=20, edgecolors='black', linewidth=0.3)

    title = 'UMAP: Por Cluster'
    if silhouette_val is not None:
        title += f'\nSilhouette: {silhouette_val:.3f}'
    ax1.set_title(title, fontsize=14, fontweight='bold')
    ax1.set_xlabel('UMAP 1')
    ax1.set_ylabel('UMAP 2')
    ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8, ncol=2)

    # By class: sorted so the colour assigned to each class is stable across
    # runs (set iteration order of strings is not).
    ax2 = axes[1]
    classes = sorted({m['class'] for m in metadata_sample})
    class_colors = plt.cm.Set1(np.linspace(0, 1, len(classes)))
    sample_classes = np.array([m['class'] for m in metadata_sample])

    for color, class_name in zip(class_colors, classes):
        mask = sample_classes == class_name
        ax2.scatter(coords_2d[mask, 0], coords_2d[mask, 1],
                    c=[color], label=class_name,
                    alpha=0.6, s=20, edgecolors='black', linewidth=0.3)

    ax2.set_title('UMAP: Por Classe Original', fontsize=14, fontweight='bold')
    ax2.set_xlabel('UMAP 1')
    ax2.set_ylabel('UMAP 2')
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot UMAP salvo em {save_path}")

    plt.show()
    plt.close()


def visualize_patch_examples(patches, labels, metadata_list, save_path=None, n_examples=5):
    """Show a grid of example patches: one row per cluster, up to ``n_examples`` columns.

    Each row is labelled with its cluster id. Each patch is titled with the
    ``'class'`` from its metadata. Clusters larger than ``n_examples`` are
    subsampled with a local RNG (seed 42); unused cells are left blank.

    Args:
        patches: Sequence of ``(patch_img, metadata)`` tuples, aligned with ``labels``.
        labels: Cluster label per patch (any ids, not necessarily 0..k-1).
        metadata_list: Per-patch metadata. Kept for API compatibility; the
            metadata shown is read from ``patches``.
        save_path: Optional file path for the figure.
        n_examples: Maximum number of patches shown per cluster.
    """
    print("\nCriando visualização de patches...")

    labels = np.asarray(labels)
    cluster_ids = np.unique(labels)
    n_rows = len(cluster_ids)

    if n_rows == 0 or n_examples < 1:
        print("Nenhum patch para visualizar.")
        return

    rng = np.random.RandomState(42)

    # squeeze=False always yields a 2-D axes array, including for a single
    # row or a single column.
    fig, axes = plt.subplots(n_rows, n_examples, figsize=(n_examples*2, n_rows*2), squeeze=False)

    for row, cluster_id in enumerate(cluster_ids):
        members = np.flatnonzero(labels == cluster_id)

        if len(members) > n_examples:
            chosen = rng.choice(members, n_examples, replace=False)
        else:
            chosen = members

        for col in range(n_examples):
            ax = axes[row, col]

            if col < len(chosen):
                patch_img, metadata = patches[chosen[col]]
                ax.imshow(patch_img)
                ax.set_title(metadata['class'], fontsize=8)

            # Hide ticks and frame but keep the axis enabled: ``axis('off')``
            # would also hide the row's cluster ylabel.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            if col == 0:
                ax.set_ylabel(f'Cluster {cluster_id}', fontsize=10, fontweight='bold')

    plt.suptitle('Patches Exemplo por Cluster', fontsize=14, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=200, bbox_inches='tight')
        print(f"Exemplos salvos em {save_path}")

    plt.show()
    plt.close()

print("Multi-Scale Patch Clustering definido!")

---
## 8. Pipeline de Avaliação <a id="8-pipeline"></a>

Pipeline completo para avaliação do modelo SAMI.

In [None]:
def run_full_evaluation(data_path, 
                        model_path=None,
                        model_arch='vit_small',
                        output_dir='./results',
                        k_neighbors=7,
                        batch_size=32,
                        img_size=224):
    """
    Full SAMI evaluation pipeline.

    Loads the dataset and a ViT backbone and extracts embeddings. It then
    runs CBIR, KNN classification, per-class metrics and a comparison of k
    values. Finally it writes plots, a text report and the embeddings to
    ``output_dir``.

    Args:
        data_path: Path to the dataset (one sub-folder per class).
        model_path: Optional path to a state_dict to load into the model.
        model_arch: Architecture, ``'vit_small'`` or ``'vit_base'``.
        output_dir: Output directory (created if missing).
        k_neighbors: k for the KNN classifier.
        batch_size: Batch size for feature extraction.
        img_size: Image side length used when loading the dataset.

    Returns:
        Dict with ``embeddings``, ``labels``, ``knn_metrics``,
        ``cbir_results``, ``class_metrics`` and ``k_comparison``.

    Raises:
        ValueError: if ``model_arch`` is not a supported architecture.
    """
    # Validate before doing any work. Previously an unknown architecture left
    # ``model`` unbound and failed only after the whole dataset was loaded.
    supported_archs = ('vit_small', 'vit_base')
    if model_arch not in supported_archs:
        raise ValueError(f"model_arch inválido: {model_arch!r}; use um de {supported_archs}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    print("="*80)
    print("SAMI - Sponge Automatic Model Identification")
    print("Pipeline de Avaliação")
    print("="*80)
    print(f"Data path: {data_path}")
    print(f"Output: {output_dir}")
    print("="*80 + "\n")
    
    # Load class names
    class_names = get_class_names(data_path)
    print(f"Classes encontradas ({len(class_names)}):")
    for i, name in enumerate(class_names):
        print(f"  {i}: {name}")
    print()
    
    # Load dataset
    print("Carregando imagens...")
    images_tensor, image_paths, labels = load_dataset_images(data_path, img_size=img_size)
    labels = np.array(labels)
    
    # Load model
    print(f"\nCarregando modelo {model_arch}...")
    if model_arch == 'vit_small':
        model = vit_small(patch_size=16)
    else:
        model = vit_base(patch_size=16)
    
    if model_path:
        print(f"Carregando pesos de {model_path}")
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only load trusted checkpoints (on recent PyTorch,
        # consider weights_only=True).
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
    else:
        print("AVISO: Usando inicialização aleatória (sem pesos pré-treinados)")
    
    model.eval()
    model = model.to(DEVICE)
    print(f"Modelo carregado em {DEVICE}")
    print(f"Parâmetros: {count_parameters(model):,}\n")
    
    # Extract features
    print("Extraindo features...")
    embeddings = extract_features(model, images_tensor, batch_size=batch_size, device=DEVICE)
    print(f"Shape dos embeddings: {embeddings.shape}\n")
    
    # Evaluation 1: CBIR
    print("="*80)
    print("AVALIAÇÃO 1: Content-Based Image Retrieval (CBIR)")
    print("="*80)
    cbir_results = evaluate_cbir(embeddings, labels, k_values=[1, 5, 10, 20], metric='cosine')
    print("Resultados CBIR:")
    for metric, value in cbir_results.items():
        print(f"  {metric}: {value:.4f}")
    print()
    
    # Evaluation 2: KNN
    print("="*80)
    print("AVALIAÇÃO 2: K-Nearest Neighbors Classification")
    print("="*80)
    
    print(f"Treinando KNN (k={k_neighbors})...")
    # NOTE(review): the KNN is fitted and evaluated on the same embeddings.
    # Unless evaluate_knn excludes each query's own match, every sample is
    # its own nearest neighbour and these scores are optimistic; confirm.
    knn = train_knn_classifier(embeddings, labels, n_neighbors=k_neighbors, metric='cosine')
    
    print("Avaliando KNN...")
    knn_metrics = evaluate_knn(knn, embeddings, labels, class_names=class_names)
    
    print("\nResultados KNN:")
    print(f"  Accuracy: {knn_metrics['accuracy']:.4f}")
    print(f"  Macro F1: {knn_metrics['macro_f1']:.4f}")
    print(f"  Weighted F1: {knn_metrics['weighted_f1']:.4f}")
    print(f"  Macro Precision: {knn_metrics['macro_precision']:.4f}")
    print(f"  Macro Recall: {knn_metrics['macro_recall']:.4f}")
    print()
    
    # Evaluation 3: per-class metrics
    print("="*80)
    print("AVALIAÇÃO 3: Métricas por Classe")
    print("="*80)
    class_metrics = compute_class_wise_metrics(labels, knn_metrics['predictions'], class_names)
    print(class_metrics.to_string(index=False))
    class_metrics.to_csv(output_dir / 'class_metrics.csv', index=False)
    print(f"\nMétricas salvas em {output_dir / 'class_metrics.csv'}\n")
    
    # Evaluation 4: comparison of k values
    print("="*80)
    print("AVALIAÇÃO 4: Comparação de Valores de K")
    print("="*80)
    k_comparison = compare_k_values(embeddings, labels, k_values=[1, 3, 5, 7, 9, 11, 15], metric='cosine')
    print(k_comparison.to_string(index=False))
    k_comparison.to_csv(output_dir / 'k_comparison.csv', index=False)
    print(f"\nComparação salva em {output_dir / 'k_comparison.csv'}\n")
    
    # Visualizations
    print("="*80)
    print("GERANDO VISUALIZAÇÕES")
    print("="*80)
    
    print("\nMatriz de Confusão...")
    plot_confusion_matrix(
        knn_metrics['confusion_matrix'],
        class_names,
        title=f'SAMI Confusion Matrix (k={k_neighbors})',
        save_path=str(output_dir / 'confusion_matrix.png')
    )
    
    print("\nt-SNE...")
    plot_tsne(
        embeddings,
        labels,
        class_names,
        title='SAMI t-SNE Embedding Visualization',
        save_path=str(output_dir / 't-sne_visualization.png'),
        # Perplexity must be positive; len // 5 is 0 for datasets with
        # fewer than 5 samples.
        perplexity=max(1, min(30, len(embeddings) // 5))
    )
    
    # Save report
    print("\nSalvando relatório...")
    all_metrics = {**knn_metrics, 'cbir': cbir_results}
    save_evaluation_report(all_metrics, class_names, str(output_dir / 'evaluation_report.txt'))
    
    # Save embeddings
    save_embeddings(embeddings, labels.tolist(), image_paths, str(output_dir / 'embeddings.npz'))
    
    print("\n" + "="*80)
    print("AVALIAÇÃO COMPLETA!")
    print("="*80)
    print(f"\nResultados salvos em: {output_dir}")
    print("\nArquivos gerados:")
    print("  - class_metrics.csv")
    print("  - k_comparison.csv")
    print("  - confusion_matrix.png")
    print("  - t-sne_visualization.png")
    print("  - evaluation_report.txt")
    print("  - embeddings.npz")
    print("="*80)
    
    return {
        'embeddings': embeddings,
        'labels': labels,
        'knn_metrics': knn_metrics,
        'cbir_results': cbir_results,
        'class_metrics': class_metrics,
        'k_comparison': k_comparison
    }

---
## 9. Exemplo de Uso <a id="9-example"></a>

Demonstração básica do uso do SAMI.

In [None]:
# ============================================================================
# EXEMPLO 1: Avaliação Completa
# ============================================================================
# Descomente e configure os caminhos para executar

# results = run_full_evaluation(
#     data_path='./imagefolder_cambrian_sponges',
#     model_path=None,  # ou './path/to/weights.pth'
#     model_arch='vit_small',
#     output_dir='./results',
#     k_neighbors=7,
#     batch_size=32
# )

In [None]:
# ============================================================================
# EXEMPLO 2: Clustering Analysis
# ============================================================================
# Descomente e configure para executar análise de clustering

# data_path = './imagefolder_cambrian_sponges'
# output_dir = Path('./clustering_results')
# output_dir.mkdir(exist_ok=True)

# # Carregar dados
# images_tensor, image_paths, labels = load_dataset_images(data_path, img_size=224)

# # Carregar modelo
# model = vit_small(patch_size=16)
# model.eval()
# model = model.to(DEVICE)

# # Extrair features
# embeddings = extract_features(model, images_tensor, batch_size=32, device=DEVICE)
# embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

# # t-SNE
# coords_2d = compute_tsne_clustering(embeddings, perplexity=30)

# # K-Means
# results, df = kmeans_clustering(embeddings, [3, 5, 7, 10], coords_2d, output_dir)

# # DBSCAN
# dbscan_clustering(embeddings, coords_2d, output_dir)

# # Hierarchical
# hierarchical_clustering(embeddings, coords_2d, output_dir)

In [None]:
# ============================================================================
# EXEMPLO 3: Multi-Scale Patch Clustering
# ============================================================================
# Descomente e configure para análise de patches

# data_path = Path('./imagefolder_cambrian_sponges')
# output_dir = Path('./patch_clustering_results')
# output_dir.mkdir(exist_ok=True)

# window_sizes = [64, 128, 256]
# stride = 32
# max_patches_per_image = 50
# min_content_ratio = 0.7
# n_clusters = 10

# # Listar imagens
# class_folders = [d for d in data_path.iterdir() if d.is_dir()]
# image_paths = []
# class_names = []
# for class_folder in class_folders:
#     files = list(class_folder.glob('*.jpg')) + list(class_folder.glob('*.png'))
#     for f in files:
#         image_paths.append(f)
#         class_names.append(class_folder.name)

# # Extrair patches
# all_patches = extract_all_patches(
#     image_paths, class_names, window_sizes, 
#     stride, max_patches_per_image, min_content_ratio
# )

# # Modelo
# model = vit_small(patch_size=16)
# model.eval()
# model = model.to(DEVICE)

# # Features
# features = extract_features_from_patches(all_patches, model, DEVICE, batch_size=32)
# metadata_list = [m for _, m in all_patches]

# # Clustering
# labels, kmeans, silhouette = cluster_patches(features, n_clusters, metadata_list)

# # Visualização UMAP
# visualize_umap_patches(features, labels, metadata_list, 
#                        save_path=str(output_dir / 'umap.png'),
#                        silhouette_val=silhouette)

# # Exemplos de patches
# visualize_patch_examples(all_patches, labels, metadata_list, 
#                          save_path=str(output_dir / 'patch_examples.png'))

In [None]:
# ============================================================================
# EXEMPLO 4: Inferência em Imagem Individual
# ============================================================================

def predict_single_image(image_path, model, embeddings_db, labels_db, class_names, k=5):
    """
    Predict the class of a single image by majority vote over its nearest
    neighbours (CBIR-style).

    The query embedding and the database embeddings are L2-normalised, so
    Euclidean nearest neighbours rank the same as cosine similarity.

    Args:
        image_path: Path of the query image.
        model: ViT model producing the embeddings.
        embeddings_db: Database embeddings (one row per database image).
        labels_db: Integer class index per database row (list or array).
        class_names: Class names indexed by label.
        k: Number of neighbours. Clamped to the database size.

    Returns:
        ``(predicted_class, distances, indices)``, or ``None`` if the image
        could not be loaded. Ties in the vote go to the smallest label
        (``np.unique`` order + ``argmax``).

    Raises:
        ValueError: if ``k < 1`` or the database is empty.
    """
    labels_db = np.asarray(labels_db)  # lists must support fancy indexing below
    k = min(k, len(embeddings_db))
    if k < 1:
        raise ValueError("k deve ser >= 1 e embeddings_db não pode estar vazio")

    # Load and preprocess
    img = load_image(image_path)
    if img is None:
        print(f"Erro ao carregar {image_path}")
        return None
    
    img_tensor = preprocess_image(img, img_size=224)
    img_batch = img_tensor.unsqueeze(0).to(DEVICE)
    
    # Extract features
    model.eval()
    with torch.no_grad():
        query_features = model(img_batch).cpu().numpy()
    
    # Normalise; the epsilon guards against all-zero rows.
    query_features = query_features / (np.linalg.norm(query_features, axis=1, keepdims=True) + 1e-8)
    embeddings_norm = embeddings_db / (np.linalg.norm(embeddings_db, axis=1, keepdims=True) + 1e-8)
    
    # Find neighbours
    nn_model = NearestNeighbors(n_neighbors=k, metric='euclidean')
    nn_model.fit(embeddings_norm)
    distances, indices = nn_model.kneighbors(query_features)
    
    # Majority vote
    neighbor_labels = labels_db[indices[0]]
    unique, counts = np.unique(neighbor_labels, return_counts=True)
    predicted_class = unique[np.argmax(counts)]
    
    print(f"\nResultados para: {image_path}")
    print(f"Classe predita: {class_names[predicted_class]}")
    print(f"\nTop {k} vizinhos mais próximos:")
    for i, (dist, idx) in enumerate(zip(distances[0], indices[0])):
        print(f"  {i+1}. {class_names[labels_db[idx]]} (distância: {dist:.4f})")
    
    return predicted_class, distances, indices


# Exemplo de uso:
# predicted, dists, idxs = predict_single_image(
#     './test_image.jpg',
#     model,
#     embeddings,
#     labels,
#     class_names,
#     k=5
# )

---
## Estrutura de Diretórios Esperada

```
imagefolder_cambrian_sponges/
├── Archaeocyatha_sp1/
│   ├── specimen_001.jpg
│   ├── specimen_002.jpg
│   └── ...
├── Porifera_sp2/
│   └── ...
└── ...
```

---

## Notas

- Para usar pesos pré-treinados, forneça o caminho em `model_path`
- UMAP requer `pip install umap-learn`
- Para datasets grandes, ajuste `batch_size` conforme memória disponível
- O Silhouette Score varia de -1 a 1; valores acima de 0,5 costumam indicar boa separação entre clusters (regra prática, não um limiar rígido)