In [41]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [42]:
# Preprocessing shared by every image: resize to 224x224, convert to a tensor,
# then normalise each channel with the given mean / std.
# NOTE(review): these look like the commonly used ImageNet channel statistics;
# confirm they match the preprocessing expected by whichever pretrained
# backbone consumes these tensors.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [43]:
# Folder-per-class image datasets; both splits go through the same `transform`.
train_dataset = datasets.ImageFolder('/content/dataset/ham10000/train', transform=transform)

test_dataset = datasets.ImageFolder('/content/dataset/ham10000/test', transform=transform)

In [44]:
# Batches of 8 images; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


In [45]:
# Quick sanity check of the class list and the size of each split.
print(f"Classes: {train_dataset.classes}")
print(f"Número de imagens treino: {len(train_dataset)}")
print(f"Número de imagens teste: {len(test_dataset)}")


Classes: ['akiec', 'bcc', 'bkl', 'mel', 'nv']
Número de imagens treino: 25
Número de imagens teste: 75


In [51]:
# Estratégia baseada em Nurgazin et al. (2023): ViT + ProtoNet
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report
import timm  # Biblioteca para Vision Transformers

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 1. Load the pretrained Vision Transformer (the author's note: better than
#    ResNet for few-shot learning).
# NOTE(review): the images fed to this model are normalised by the `transform`
# defined earlier in the notebook; confirm those statistics match the
# preprocessing this checkpoint was trained with (timm exposes it on the model
# as `pretrained_cfg`).
print("Carregando Vision Transformer...")
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=0,  # num_classes=0 removes the last layer
)
# Inference only: switch to eval mode, then move the weights to `device`.
model = model.eval().to(device)

# 2. Extrair features
def extract_features_vit(loader, model, device=None):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs of tensors.
        model: ``torch.nn.Module`` called as ``model(images)``; whatever it
            returns for a batch is stored as that batch's features.
        device: Device on which the forward passes run. When ``None`` the
            device of the model's first parameter is used (CPU if the model
            has no parameters), so the function does not depend on a
            notebook-level ``device`` variable.

    Returns:
        Tuple ``(features, labels)``: the per-batch outputs and labels
        concatenated along dim 0, both located on ``device``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    features = []
    labels = []

    # Force eval mode for the extraction (train-mode layers such as dropout
    # would make the features stochastic) and restore the caller's mode after.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for imgs, lbls in loader:
                feats = model(imgs.to(device))
                features.append(feats)
                labels.append(lbls.to(device))
    finally:
        model.train(was_training)

    if not features:
        # torch.cat on an empty list fails with a far less helpful message.
        raise ValueError("loader produced no batches; cannot extract features")

    return torch.cat(features), torch.cat(labels)

print("Extraindo features com ViT...")
support_features, support_labels = extract_features_vit(train_loader, model)
query_features, query_labels = extract_features_vit(test_loader, model)

# Query labels are compared against prototype indices and reported under
# train_dataset.classes, so both datasets must expose the same class list;
# otherwise the label indices would silently refer to different classes.
if test_dataset.classes != train_dataset.classes:
    raise ValueError(
        "train and test datasets have different class lists: "
        f"{train_dataset.classes} vs {test_dataset.classes}"
    )

# 3. L2 normalisation (crucial for metric similarity)
support_features = F.normalize(support_features, p=2, dim=1)
query_features = F.normalize(query_features, p=2, dim=1)

# 4. Build one L2-normalised prototype per class (mean of its support features)
num_classes = len(train_dataset.classes)
feature_device = support_features.device
prototypes = torch.zeros(num_classes, support_features.size(1), device=feature_device)
# Tracks which classes actually have at least one support sample.
has_support = torch.zeros(num_classes, dtype=torch.bool, device=feature_device)

for c in range(num_classes):
    mask = (support_labels == c)
    if mask.any():
        class_mean = support_features[mask].mean(dim=0)
        prototypes[c] = F.normalize(class_mean, p=2, dim=0)
        has_support[c] = True

# 5. Classify by COSINE SIMILARITY (dot product of unit-norm vectors),
# which suits normalised features better than Euclidean distance.
similarities = torch.mm(query_features, prototypes.t())
# A class without support samples keeps an all-zero prototype (similarity 0),
# which could still win the argmax when every real similarity is negative.
# Exclude such classes from the prediction.
similarities[:, ~has_support] = float('-inf')
predictions = similarities.argmax(dim=1)

# 6. Evaluate
accuracy = (predictions == query_labels).float().mean().item()
print(f'\n{"="*60}')
print(f'Acurácia ViT + Prototypical Networks: {accuracy*100:.2f}%')
print(f'{"="*60}')

print("\nRelatório de Classificação:")
print(classification_report(
    query_labels.cpu().numpy(),
    predictions.cpu().numpy(),
    # Pin the label set so the report still lines up with target_names when a
    # class is absent from both the query labels and the predictions.
    labels=list(range(num_classes)),
    target_names=train_dataset.classes,
    digits=3,
    zero_division=0
))

Carregando Vision Transformer...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Extraindo features com ViT...

Acurácia ViT + Prototypical Networks: 26.67%

Relatório de Classificação:
              precision    recall  f1-score   support

       akiec      0.400     0.267     0.320        15
         bcc      0.318     0.467     0.378        15
         bkl      0.286     0.133     0.182        15
         mel      0.238     0.333     0.278        15
          nv      0.133     0.133     0.133        15

    accuracy                          0.267        75
   macro avg      0.275     0.267     0.258        75
weighted avg      0.275     0.267     0.258        75

