In [64]:
import torch
import importlib
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
importlib.reload(__import__('helpers'))
from helpers import save_fewshot_results

In [28]:
# Few-shot evaluation settings read by the cells below.
CONFIG = dict(
    n_way=3,            # number of classes per episode
    n_shot=5,           # training (support) examples per class
    n_query=15,         # test (query) examples per class
    n_episodes=200,     # number of episodes used for evaluation
    model_name='deit_base_distilled_patch16_224',
    # Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [29]:
# Preprocessing applied to every image: fixed 224x224 resize, conversion to a
# tensor, then per-channel normalization.
# NOTE(review): the mean/std values look like the usual ImageNet statistics —
# confirm they match what the pretrained checkpoint expects.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [30]:
from torchvision import datasets
from collections import defaultdict

dataset = datasets.ImageFolder(
    root='/content/drive/MyDrive/pdi/dataset/ham10000/all',
    transform=transform
)

# Map class label -> list of sample indices belonging to that class.
#
# Labels are read from `dataset.targets` rather than by iterating `dataset`:
# indexing an ImageFolder loads the image from disk and runs the full
# transform pipeline, which is wasted work when only the label is needed.
# `targets[i]` is the class index of sample `i`, i.e. the same label that
# `dataset[i][1]` returns.
class_to_indices = defaultdict(list)
for idx, label in enumerate(dataset.targets):
    class_to_indices[label].append(idx)


In [31]:
import random
import torch
from torch.utils.data import Subset, DataLoader

def create_episode(
    class_to_indices,
    n_way=5,
    n_shot=10,
    n_query=15
):
    """Randomly draw the sample indices for one N-way few-shot episode.

    Args:
        class_to_indices: Mapping from class label to the list of dataset
            indices belonging to that class.
        n_way: Number of distinct classes drawn for the episode.
        n_shot: Number of support indices drawn per class.
        n_query: Number of query indices drawn per class.

    Returns:
        A tuple ``(support_idx, query_idx, classes)``. Both index lists are
        grouped by class, in the order given by ``classes``; within a class
        the support and query indices are disjoint.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``n_way`` exceeds
            the number of classes or a drawn class holds fewer than
            ``n_shot + n_query`` indices.
    """
    episode_classes = random.sample(list(class_to_indices), n_way)
    per_class = n_shot + n_query

    support_idx, query_idx = [], []
    for cls in episode_classes:
        # One draw without replacement per class, then split it in two so
        # support and query never share a sample.
        drawn = random.sample(class_to_indices[cls], per_class)
        support_idx.extend(drawn[:n_shot])
        query_idx.extend(drawn[n_shot:])

    return support_idx, query_idx, episode_classes


In [32]:
def get_episode_loaders(dataset, support_idx, query_idx):
    """Build the support and query data loaders for a single episode.

    Each loader wraps a ``Subset`` of ``dataset`` restricted to the given
    indices and uses a batch size equal to the number of indices, with
    shuffling disabled, so samples come out in index-list order.

    Args:
        dataset: Dataset the indices refer to.
        support_idx: Indices of the support samples.
        query_idx: Indices of the query samples.

    Returns:
        A tuple ``(support_loader, query_loader)``.
    """
    def _single_batch_loader(indices):
        # batch_size == len(indices): the whole subset fits in one batch.
        return DataLoader(
            Subset(dataset, indices),
            batch_size=len(indices),
            shuffle=False,
        )

    return _single_batch_loader(support_idx), _single_batch_loader(query_idx)


In [33]:
def extract_features_vit(loader, model):
    """Run ``model`` over every batch of ``loader`` and collect the outputs.

    The model is switched to eval mode and executed under ``torch.no_grad()``.
    Input batches are moved to the device of the model's first parameter; a
    model without parameters leaves the inputs where they are.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches of tensors.
        model: ``torch.nn.Module`` mapping an image batch to a feature tensor.
            In this notebook it is the DeiT backbone without its head.

    Returns:
        A tuple ``(features, labels)`` with the per-batch outputs and labels
        concatenated along dimension 0. Features stay on the model's device;
        labels are returned as the loader yielded them (not moved).
    """
    model.eval()

    # The target device does not change between batches, so resolve it once.
    # `next(..., None)` avoids a StopIteration for parameter-free modules.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    features = []
    labels = []
    with torch.no_grad():
        for imgs, lbls in loader:
            if target_device is not None:
                imgs = imgs.to(target_device)
            features.append(model(imgs))
            labels.append(lbls)

    return torch.cat(features), torch.cat(labels)


In [34]:
import torch.nn.functional as F

def evaluate_episode(
    model,
    dataset,
    class_to_indices,
    device,
    n_way=5,
    n_shot=10,
    n_query=15
):
    """Evaluate one few-shot episode with a cosine-similarity prototype classifier.

    Steps: sample an episode, extract features for the support and query
    samples, L2-normalize them, average the support features of each class
    into a prototype, L2-normalize the prototypes, and assign each query
    sample to the prototype with the highest dot product.

    Args:
        model: Feature extractor passed to ``extract_features_vit``.
        dataset: Dataset the sampled indices refer to.
        class_to_indices: Mapping from class label to dataset indices.
        device: Device on which labels, prototypes and similarities are
            computed. Extracted features are moved there as well.
        n_way: Number of classes per episode.
        n_shot: Support samples per class.
        n_query: Query samples per class.

    Returns:
        A 5-tuple:
        ``acc`` (float, fraction of correctly classified query samples),
        ``preds_remapped`` (NumPy array, predictions in ``[0, n_way - 1]``),
        ``labels_remapped`` (NumPy array, query labels in ``[0, n_way - 1]``),
        ``original_preds`` (NumPy array, predictions as dataset class labels),
        ``original_query_labels`` (NumPy array, query labels as dataset
        class labels).
    """
    # Local import: at notebook level numpy is only imported in a later cell,
    # so the function must not rely on a global `np` being defined already.
    import numpy as np

    support_idx, query_idx, classes = create_episode(
        class_to_indices, n_way, n_shot, n_query
    )
    support_loader, query_loader = get_episode_loaders(
        dataset, support_idx, query_idx
    )

    support_features, support_labels = extract_features_vit(
        support_loader, model
    )
    query_features, query_labels = extract_features_vit(
        query_loader, model
    )

    # Keep the dataset-level labels before they are remapped below.
    original_query_labels = query_labels.cpu().numpy()

    # Move the features to `device` so they share a device with the labels
    # and prototypes built below, then L2-normalize each feature vector.
    support_features = F.normalize(support_features.to(device), p=2, dim=1)
    query_features = F.normalize(query_features.to(device), p=2, dim=1)

    # Remap dataset labels to episode-local labels in [0, n_way - 1];
    # position in `classes` defines the local label.
    label_map = {c: i for i, c in enumerate(classes)}
    support_labels = torch.tensor(
        [label_map[lbl] for lbl in support_labels.tolist()],
        device=device
    )
    query_labels_remapped = torch.tensor(
        [label_map[lbl] for lbl in query_labels.tolist()],
        device=device
    )

    # One prototype per class: mean of that class's support features.
    # Stacking keeps the dtype and device of the features themselves.
    prototypes = torch.stack([
        support_features[support_labels == i].mean(0)
        for i in range(n_way)
    ])

    # Re-normalize so the dot product below is a cosine similarity.
    prototypes = F.normalize(prototypes, p=2, dim=1)

    # Similarity of every query sample to every prototype.
    sims = torch.mm(query_features, prototypes.t())
    preds_remapped = sims.argmax(dim=1)

    # Episode accuracy, computed on the remapped labels.
    acc = (preds_remapped == query_labels_remapped).float().mean().item()

    # Map the predictions back to the dataset's class labels.
    preds_remapped_np = preds_remapped.cpu().numpy()
    original_preds = np.array([classes[p] for p in preds_remapped_np.tolist()])

    return acc, preds_remapped_np, query_labels_remapped.cpu().numpy(), original_preds, original_query_labels


In [None]:
import timm
import torch

# Show which device the model will run on.
print(CONFIG['device'])

# num_classes=0 is passed so the model is created without a classification
# head; pretrained weights are requested and the model is placed on the
# configured device.
model = timm.create_model(
    CONFIG['model_name'], pretrained=True, num_classes=0
).to(CONFIG['device'])

# Inference mode for dropout/normalization layers; as the cell's last
# expression this also displays the module.
model.eval()


In [36]:
import numpy as np

# Episode sampling draws from Python's `random` module. Seed it here so that
# re-running this cell (or the whole notebook) evaluates the same episodes
# and reports the same numbers. Set CONFIG['seed'] to change the seed.
SEED = CONFIG.get('seed', 42)
random.seed(SEED)

accuracies = []
# Remapped labels (episode-local, independent for each episode)
all_predictions_remapped = []
all_labels_remapped = []
# Original labels (real dataset classes)
all_predictions_original = []
all_labels_original = []

model.eval()

for ep in range(CONFIG['n_episodes']):
    acc, preds_remapped, labels_remapped, preds_original, labels_original = evaluate_episode(
        model,
        dataset,
        class_to_indices,
        CONFIG['device'],
        n_way=CONFIG['n_way'],
        n_shot=CONFIG['n_shot'],
        n_query=CONFIG['n_query']
    )
    accuracies.append(acc)
    # Keep both label versions so metrics can be computed either way.
    all_predictions_remapped.extend(preds_remapped)
    all_labels_remapped.extend(labels_remapped)
    all_predictions_original.extend(preds_original)
    all_labels_original.extend(labels_original)
    print(f"Epis√≥dio {ep+1}: {acc*100:.2f}%")

# Mean and (population) standard deviation of the per-episode accuracies.
mean_acc = np.mean(accuracies)
std_acc  = np.std(accuracies)

print(f"\nAcur√°cia final: {mean_acc*100:.2f}% ¬± {std_acc*100:.2f}%")

Epis√≥dio 1: 82.22%
Epis√≥dio 2: 53.33%
Epis√≥dio 3: 66.67%
Epis√≥dio 4: 44.44%
Epis√≥dio 5: 62.22%
Epis√≥dio 6: 97.78%
Epis√≥dio 7: 60.00%
Epis√≥dio 8: 73.33%
Epis√≥dio 9: 68.89%
Epis√≥dio 10: 64.44%
Epis√≥dio 11: 71.11%
Epis√≥dio 12: 64.44%
Epis√≥dio 13: 84.44%
Epis√≥dio 14: 44.44%
Epis√≥dio 15: 77.78%
Epis√≥dio 16: 68.89%
Epis√≥dio 17: 68.89%
Epis√≥dio 18: 73.33%
Epis√≥dio 19: 71.11%
Epis√≥dio 20: 48.89%
Epis√≥dio 21: 73.33%
Epis√≥dio 22: 75.56%
Epis√≥dio 23: 64.44%
Epis√≥dio 24: 71.11%
Epis√≥dio 25: 77.78%
Epis√≥dio 26: 46.67%
Epis√≥dio 27: 64.44%
Epis√≥dio 28: 57.78%
Epis√≥dio 29: 64.44%
Epis√≥dio 30: 40.00%
Epis√≥dio 31: 75.56%
Epis√≥dio 32: 77.78%
Epis√≥dio 33: 57.78%
Epis√≥dio 34: 33.33%
Epis√≥dio 35: 64.44%
Epis√≥dio 36: 68.89%
Epis√≥dio 37: 57.78%
Epis√≥dio 38: 66.67%
Epis√≥dio 39: 77.78%
Epis√≥dio 40: 53.33%
Epis√≥dio 41: 71.11%
Epis√≥dio 42: 68.89%
Epis√≥dio 43: 60.00%
Epis√≥dio 44: 77.78%
Epis√≥dio 45: 53.33%
Epis√≥dio 46: 57.78%
Epis√≥dio 47: 64.44%
Epis√≥dio 48: 60.00%
E

## 📊 Por que Precision = Recall = F1 = Acurácia?

**Problema:** Quando os labels são **remapeados** para [0, 1, 2] em cada episódio:
- Episódio 1: [melanoma=0, nevo=1, ceratose=2]
- Episódio 2: [nevo=0, basalioma=1, melanoma=2]
- Episódio 3: [ceratose=0, melanoma=1, dermatofibroma=2]

Ao agregar 200 episódios, cada slot [0, 1, 2] recebe uma **distribuição uniforme** de todas as classes reais. Com amostragem aleatória, a matriz de confusão fica balanceada e **todas as métricas convergem para o mesmo valor**!

**Solução:** Manter os **labels originais** (classes reais) para calcular métricas que fazem sentido semântico.

In [None]:
# Convert the accumulated per-sample lists to NumPy arrays (same names).
(
    all_predictions_remapped,
    all_labels_remapped,
    all_predictions_original,
    all_labels_original,
) = (
    np.array(values)
    for values in (
        all_predictions_remapped,
        all_labels_remapped,
        all_predictions_original,
        all_labels_original,
    )
)

# Class names as exposed by the dataset.
class_names = dataset.classes

# Save the results using the ORIGINAL class labels plus the class names.
exp_dir = save_fewshot_results(
    experiment_name="deit_prototypical_fewshot",
    model_name=CONFIG['model_name'],
    metric_name="Similaridade Cosine",
    normalization="Normalizacao L2",
    accuracies=accuracies,
    n_way=CONFIG['n_way'],
    n_shot=CONFIG['n_shot'],
    n_query=CONFIG['n_query'],
    n_episodes=CONFIG['n_episodes'],
    device=CONFIG['device'],
    all_predictions=all_predictions_original,
    all_labels=all_labels_original,
    class_names=class_names,
)


COMPARA√á√ÉO DE M√âTRICAS:

1. M√©tricas com labels REMAPEADOS (cada epis√≥dio = [0,1,2]):
   (Labels perdem significado sem√¢ntico - n√£o recomendado!)
   Precision (macro): 65.00%
   Recall (macro):    65.00%
   F1-Score (macro):  64.99%
   Acur√°cia:          65.00%
   ‚Üí Todas iguais! N√£o t√™m significado real.

2. M√©tricas com labels ORIGINAIS (classes reais do dataset):
   (M√©tricas por classe real - correto!)
   Precision (macro): 64.60%
   Recall (macro):    64.90%
   F1-Score (macro):  64.64%
   Acur√°cia:          65.00%
   ‚Üí Valores diferentes! Fazem sentido sem√¢ntico.

3. Acur√°cia m√©dia dos epis√≥dios (m√©todo padr√£o few-shot):
   Acur√°cia:  65.00% ¬± 12.21%
   ‚Üí M√©trica correta para few-shot epis√≥dico!

M√âTRICAS DETALHADAS POR CLASSE (Labels Originais):

Classes do dataset: ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']

              precision    recall  f1-score   support

       akiec     0.7003    0.7364    0.7179      1320
         bcc     0.5820 