In [None]:
import torch
import importlib
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
importlib.reload(__import__('helpers'))
from helpers import save_fewshot_results

In [None]:
# Experiment configuration: every value a reader might want to tune.
CONFIG = {
    'n_way': 3,           # number of classes per episode
    'n_shot': 5,          # support (labelled) examples per class
    'n_query': 15,        # query (evaluated) examples per class
    'n_episodes': 200,    # number of evaluation episodes
    'seed': 42,           # RNG seed so sampled episodes are reproducible
    'model_name': 'deit_base_distilled_patch16_224',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

In [None]:
# Preprocessing applied to every image: resize to a fixed 224x224 input
# (matching the "224" in the model name), convert to a tensor, then normalise
# each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
from torchvision import datasets
from collections import defaultdict

dataset = datasets.ImageFolder(
    root='/content/drive/MyDrive/pdi/dataset/ham10000/all',
    transform=transform
)

# Map each class id -> list of sample indices in `dataset`.
# ImageFolder already stores every sample's label in `.targets`; iterating
# `dataset` itself would load and transform every image just to read labels.
class_to_indices = defaultdict(list)
for idx, label in enumerate(dataset.targets):
    class_to_indices[label].append(idx)


In [None]:
import random
import torch
from torch.utils.data import Subset, DataLoader

def create_episode(
    class_to_indices,
    n_way=5,
    n_shot=10,
    n_query=15,
    rng=None
):
    """Sample the indices of one N-way, K-shot few-shot episode.

    Args:
        class_to_indices: mapping class id -> list of dataset indices.
        n_way: number of distinct classes drawn for the episode.
        n_shot: support examples drawn per class.
        n_query: query examples drawn per class.
        rng: optional ``random.Random`` instance for reproducible sampling;
            when None the module-level ``random`` generator is used.

    Returns:
        (support_idx, query_idx, classes). Support and query index lists are
        grouped class by class in the order of ``classes`` and never overlap;
        ``classes`` is the list of sampled class ids.

    Raises:
        ValueError: if fewer than ``n_way`` classes exist, or a sampled class
            has fewer than ``n_shot + n_query`` examples.
    """
    sampler = random if rng is None else rng

    available = list(class_to_indices.keys())
    if n_way > len(available):
        raise ValueError(
            f"n_way={n_way} exceeds the {len(available)} available classes"
        )
    classes = sampler.sample(available, n_way)

    per_class = n_shot + n_query
    support_idx = []
    query_idx = []

    for c in classes:
        pool = class_to_indices[c]
        if len(pool) < per_class:
            raise ValueError(
                f"class {c!r} has {len(pool)} samples, "
                f"need n_shot + n_query = {per_class}"
            )
        # One draw without replacement, then split -> support/query are disjoint.
        picked = sampler.sample(pool, per_class)
        support_idx += picked[:n_shot]
        query_idx += picked[n_shot:]

    return support_idx, query_idx, classes


In [None]:
def get_episode_loaders(dataset, support_idx, query_idx):
    """Build one single-batch, unshuffled DataLoader per episode split.

    Each loader yields its entire split as one batch, in exactly the order
    of the given index list.

    Returns:
        (support_loader, query_loader)
    """
    def _whole_split_loader(indices):
        # batch_size == split size -> a single batch holding every sample
        return DataLoader(
            Subset(dataset, indices),
            batch_size=len(indices),
            shuffle=False,
        )

    support_loader = _whole_split_loader(support_idx)
    query_loader = _whole_split_loader(query_idx)
    return support_loader, query_loader


In [None]:
def extract_features_vit(loader, model):
    """Run ``model`` over every batch in ``loader`` and collect the outputs.

    Switches the model to eval mode and disables autograd. Images are moved
    to the device of the model's first parameter; labels are left where the
    loader produced them.

    Returns:
        (features, labels): model outputs and labels, each concatenated
        along dim 0 across batches.
    """
    model.eval()
    collected_feats, collected_labels = [], []

    with torch.no_grad():
        for batch_imgs, batch_lbls in loader:
            target_device = next(model.parameters()).device
            # Model output; for a headless backbone (num_classes=0) this is
            # presumably the embedding vector — TODO confirm for other models.
            collected_feats.append(model(batch_imgs.to(target_device)))
            collected_labels.append(batch_lbls)

    return torch.cat(collected_feats), torch.cat(collected_labels)


In [None]:
import torch.nn.functional as F

def evaluate_episode(
    model,
    dataset,
    class_to_indices,
    device,
    n_way=5,
    n_shot=10,
    n_query=15
):
    """Evaluate one prototypical-network episode using cosine similarity.

    Samples an episode and embeds its support and query images with
    ``model``. The embeddings are L2-normalised. Each class prototype is the
    mean of its normalised support embeddings, normalised again. Every query
    is assigned to the prototype with the highest cosine similarity.

    Args:
        model: feature extractor passed to ``extract_features_vit``.
        dataset: indexable dataset returning (image, label).
        class_to_indices: mapping class id -> list of dataset indices.
        device: device on which similarity computations are performed.
        n_way, n_shot, n_query: episode shape (see ``create_episode``).

    Returns:
        acc: episode accuracy as a Python float.
        preds_remapped: np.ndarray of predicted episode labels in [0, n_way).
        query_labels_remapped: np.ndarray of true episode labels in [0, n_way).
        original_preds: np.ndarray of predictions mapped back to dataset class ids.
        original_query_labels: np.ndarray of true dataset class ids.
    """
    # Local import: numpy was otherwise only imported in a later notebook cell.
    import numpy as np

    support_idx, query_idx, classes = create_episode(
        class_to_indices, n_way, n_shot, n_query
    )

    support_loader, query_loader = get_episode_loaders(
        dataset, support_idx, query_idx
    )

    support_features, support_labels = extract_features_vit(
        support_loader, model
    )
    query_features, query_labels = extract_features_vit(
        query_loader, model
    )

    # Keep the true dataset class ids before remapping.
    original_query_labels = query_labels.cpu().numpy()

    # Put features on `device` so features, labels and prototypes all share
    # one device (features come back on the model's device, labels on CPU).
    support_features = F.normalize(support_features.to(device), p=2, dim=1)
    query_features = F.normalize(query_features.to(device), p=2, dim=1)

    # Remap dataset class ids to episode labels 0..n_way-1 (index in `classes`).
    label_map = {c: i for i, c in enumerate(classes)}
    support_labels = torch.tensor(
        [label_map[lbl] for lbl in support_labels.tolist()],
        device=device
    )
    query_labels_remapped = torch.tensor(
        [label_map[lbl] for lbl in query_labels.tolist()],
        device=device
    )

    # One prototype per episode class: mean support embedding, re-normalised
    # so the dot product below is a cosine similarity.
    prototypes = torch.stack(
        [support_features[support_labels == i].mean(0) for i in range(n_way)]
    )
    prototypes = F.normalize(prototypes, p=2, dim=1)

    sims = query_features @ prototypes.t()
    preds_remapped = sims.argmax(dim=1)

    # Episode accuracy, computed on the remapped labels.
    acc = (preds_remapped == query_labels_remapped).float().mean().item()

    preds_np = preds_remapped.cpu().numpy()
    # Map episode predictions back to dataset class ids.
    original_preds = np.array([classes[p] for p in preds_np.tolist()])

    return acc, preds_np, query_labels_remapped.cpu().numpy(), original_preds, original_query_labels


In [None]:
import timm
import torch

print(CONFIG['device'])

# Pretrained backbone with num_classes=0 so the classification head is
# removed and the forward pass yields embeddings rather than class logits.
model = timm.create_model(
    CONFIG['model_name'], pretrained=True, num_classes=0
).to(CONFIG['device'])

model.eval()


In [None]:
import numpy as np

# Seed Python's RNG (create_episode samples with `random.sample`) so the
# sampled episodes, and therefore the reported accuracy, are reproducible.
random.seed(CONFIG.get('seed', 42))

accuracies = []
# Episode-local labels (remapped to 0..n_way-1 independently in each episode)
all_predictions_remapped = []
all_labels_remapped = []
# Original dataset class ids
all_predictions_original = []
all_labels_original = []

# No explicit model.eval() here: extract_features_vit already sets eval mode.
for ep in range(CONFIG['n_episodes']):
    acc, preds_remapped, labels_remapped, preds_original, labels_original = evaluate_episode(
        model,
        dataset,
        class_to_indices,
        CONFIG['device'],
        n_way=CONFIG['n_way'],
        n_shot=CONFIG['n_shot'],
        n_query=CONFIG['n_query']
    )
    accuracies.append(acc)
    # Keep both label spaces
    all_predictions_remapped.extend(preds_remapped)
    all_labels_remapped.extend(labels_remapped)
    all_predictions_original.extend(preds_original)
    all_labels_original.extend(labels_original)
    print(f"Episódio {ep+1}: {acc*100:.2f}%")

mean_acc = np.mean(accuracies)
std_acc  = np.std(accuracies)
# 95% confidence interval of the mean accuracy, the usual few-shot summary.
ci95 = 1.96 * std_acc / np.sqrt(len(accuracies))

print(f"\nAcurácia final: {mean_acc*100:.2f}% ± {std_acc*100:.2f}%")
print(f"IC 95%: ± {ci95*100:.2f}%")

In [None]:
# Converter para arrays
all_predictions_remapped = np.array(all_predictions_remapped)
all_labels_remapped = np.array(all_labels_remapped)
all_predictions_original = np.array(all_predictions_original)
all_labels_original = np.array(all_labels_original)


# Obter nomes das classes do dataset
class_names = dataset.classes

# Salvar resultados usando as classes ORIGINAIS + nomes das classes
exp_dir = save_fewshot_results(
    experiment_name="deit_prototypical_fewshot",
    model_name=CONFIG['model_name'],
    metric_name="Similaridade Cosine",
    normalization="Normalizacao L2",
    accuracies=accuracies,
    n_way=CONFIG['n_way'],
    n_shot=CONFIG['n_shot'],
    n_query=CONFIG['n_query'],
    n_episodes=CONFIG['n_episodes'],
    device=CONFIG['device'],
    all_predictions=all_predictions_original, 
    all_labels=all_labels_original,             
    class_names=class_names                     
)