In [192]:
import torch
import importlib
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
importlib.reload(__import__('helpers'))
from helpers import save_experiment_results

In [None]:

# Per-channel normalisation constants (the standard ImageNet statistics that
# ImageNet-pretrained backbones expect their inputs to be scaled with).
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image, in order:
# fixed 224x224 resize -> tensor conversion -> per-channel normalisation.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Both splits are read with ImageFolder (one sub-directory per class) and share
# the same preprocessing pipeline. The train split is built first, then test.
train_dataset, test_dataset = (
    datasets.ImageFolder(root=split_root, transform=transform)
    for split_root in (
        '/content/dataset/ham10000/train',
        '/content/dataset/ham10000/test',
    )
)

In [195]:
# The loaders only feed a frozen feature extractor, so sample order is kept
# (shuffle=False) and each split simply uses its own batch size.
train_loader = DataLoader(train_dataset, batch_size=25, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=75, shuffle=False)


In [196]:
# Quick sanity check of what was loaded: class names and the size of each split.
n_train_images = len(train_dataset)
n_test_images = len(test_dataset)

print("Classes:", train_dataset.classes)
print("Número de imagens treino:", n_train_images)
print("Número de imagens teste:", n_test_images)


Classes: ['akiec', 'bcc', 'bkl', 'mel', 'nv']
Número de imagens treino: 50
Número de imagens teste: 250


In [None]:
# Estratégia baseada em Nurgazin et al. (2023): ViT + ProtoNet
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report
import timm  # Biblioteca para Vision Transformers

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Name under which this run's results are saved.
experiment_name = "deit_cosine_similarity"


# 1. Load the pretrained distilled DeiT backbone. num_classes=0 asks timm for a
#    model without a classification head, so the forward pass yields features.
print("Carregando DeiT...")
backbone_name = 'deit_base_distilled_patch16_224'
model = timm.create_model(backbone_name, pretrained=True, num_classes=0)
model = model.to(device)
model.eval()  # inference mode for layers such as dropout

# 2. Extrair features
def extract_features_vit(loader, model, device=None):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches of tensors.
        model: ``torch.nn.Module`` mapping an image batch to a feature batch.
            It is called as-is; switching it to eval mode is the caller's job.
        device: Device on which the forward pass runs. When omitted, the
            device of the model's first parameter is used (CPU if the model
            has no parameters), so the function no longer depends on a
            notebook-global ``device`` variable.

    Returns:
        Tuple ``(features, labels)``: the model outputs and the labels of all
        batches, each concatenated along dim 0 and located on ``device``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    features = []
    labels = []

    # Gradients are never needed for feature extraction.
    with torch.no_grad():
        for imgs, lbls in loader:
            feats = model(imgs.to(device))
            features.append(feats)
            labels.append(lbls.to(device))

    # torch.cat on an empty list fails with an opaque message; be explicit.
    if not features:
        raise ValueError("loader yielded no batches; cannot extract features")

    return torch.cat(features), torch.cat(labels)

print("Extraindo features com ViT...")
support_features, support_labels = extract_features_vit(train_loader, model)
query_features, query_labels = extract_features_vit(test_loader, model)

# 3. L2-normalise every embedding so that a plain dot product between two
#    vectors equals their cosine similarity.
support_features = F.normalize(support_features, p=2, dim=1)
query_features = F.normalize(query_features, p=2, dim=1)

# 4. One prototype per class: the mean of that class's support embeddings,
#    re-normalised to unit length.
num_classes = len(train_dataset.classes)
prototypes = torch.zeros(num_classes, support_features.size(1)).to(device)
# Records which classes actually received a prototype.
class_has_support = torch.zeros(num_classes, dtype=torch.bool, device=prototypes.device)

for c in range(num_classes):
    mask = (support_labels == c)
    if mask.sum() > 0:
        class_features = support_features[mask]
        prototypes[c] = F.normalize(class_features.mean(dim=0), p=2, dim=0)
        class_has_support[c] = True

# 5. Classify each query by cosine similarity (dot product of unit vectors)
#    to the prototypes.
similarities = torch.mm(query_features, prototypes.t())
# A class without support samples keeps an all-zero prototype, i.e. similarity
# 0 to every query, which could outrank genuinely negative similarities.
# Exclude such classes from the argmax instead of letting them compete.
similarities = similarities.masked_fill(~class_has_support, float('-inf'))
predictions_deit = similarities.argmax(dim=1)

# 6. Evaluate
accuracy = (predictions_deit == query_labels).float().mean().item()
print(f'\n{"="*60}')
print(f'Acurácia DeiT + Prototypical Networks: {accuracy*100:.2f}%')
print(f'{"="*60}')

print("\nRelatório de Classificação:")
print(classification_report(
    query_labels.cpu().numpy(), 
    predictions_deit.cpu().numpy(),
    target_names=train_dataset.classes,
    digits=3
))

# 7. Save results
exp_dir = save_experiment_results(
    experiment_name=experiment_name,
    model_name="DeiT Base Distilled Patch16 224",
    metric_name="Similaridade Cosine",
    normalization="Normalizacao L2",
    y_true=query_labels.cpu().numpy(),
    y_pred=predictions_deit.cpu().numpy(),
    class_names=train_dataset.classes,
    n_train=len(train_dataset),
    n_test=len(test_dataset),
    device=device
)

Carregando DeiT...
Extraindo features com ViT...

Acurácia DeiT + Prototypical Networks: 50.00%

Relatório de Classificação:
              precision    recall  f1-score   support

       akiec      0.562     0.540     0.551        50
         bcc      0.550     0.660     0.600        50
         bkl      0.256     0.200     0.225        50
         mel      0.434     0.460     0.447        50
          nv      0.640     0.640     0.640        50

    accuracy                          0.500       250
   macro avg      0.489     0.500     0.492       250
weighted avg      0.489     0.500     0.492       250


Resultados salvos em: /content/drive/MyDrive/pdi/resultados/deit_cosine_similarity_5way_10shot_v2
Configuração: 5-way 10-shot
Arquivos gerados:
  - config.json
  - metrics.json
  - report.txt
  - confusion_matrix.png
  - accuracy_per_class.png
  - predictions.json



In [198]:
# Voltar para ResNet50 + Distância Euclidiana + SALVAR RESULTADOS
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import torchvision.models as models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Name under which this run's results are saved.
experiment_name = "resnet50_euclidiana"


# 1. Load the ImageNet-pretrained ResNet50.
#    `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is
#    the checkpoint that flag selected, so the extracted features are unchanged.
print("Carregando ResNet50...")
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.fc = nn.Identity()  # Drop the classification layer: the model now outputs features
model.eval()
model = model.to(device)

# 2. Extrair features
def extract_features_resnet(loader, model, device=None):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches of tensors.
        model: ``torch.nn.Module`` mapping an image batch to a feature batch.
            It is called as-is; switching it to eval mode is the caller's job.
        device: Device on which the forward pass runs. When omitted, the
            device of the model's first parameter is used (CPU if the model
            has no parameters), so the function no longer depends on a
            notebook-global ``device`` variable.

    Returns:
        Tuple ``(features, labels)``: the model outputs and the labels of all
        batches, each concatenated along dim 0 and located on ``device``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    features = []
    labels = []

    # Gradients are never needed for feature extraction.
    with torch.no_grad():
        for imgs, lbls in loader:
            feats = model(imgs.to(device))
            features.append(feats)
            labels.append(lbls.to(device))

    # torch.cat on an empty list fails with an opaque message; be explicit.
    if not features:
        raise ValueError("loader yielded no batches; cannot extract features")

    return torch.cat(features), torch.cat(labels)

print("Extraindo features com ResNet50...")
support_features, support_labels = extract_features_resnet(train_loader, model)
query_features, query_labels = extract_features_resnet(test_loader, model)


# 3. One prototype per class: the plain mean of that class's support
#    embeddings (no L2 normalisation in this experiment).
num_classes = len(train_dataset.classes)
prototypes = torch.zeros(num_classes, support_features.size(1)).to(device)
# Records which classes actually received a prototype.
class_has_support = torch.zeros(num_classes, dtype=torch.bool, device=prototypes.device)

for c in range(num_classes):
    mask = (support_labels == c)
    if mask.sum() > 0:
        class_features = support_features[mask]
        prototypes[c] = class_features.mean(dim=0)  # Simple mean
        class_has_support[c] = True

# 4. Classify each query by Euclidean distance to the prototypes.
distances = torch.cdist(query_features, prototypes)
# A class without support samples keeps an all-zero prototype, which would
# still compete in the argmin (and win for queries with small-norm features).
# Push such classes to infinite distance so they can never be predicted.
distances = distances.masked_fill(~class_has_support, float('inf'))
predictions_resnet = distances.argmin(dim=1)  # Class with the smallest distance

# 5. Evaluate
accuracy = (predictions_resnet == query_labels).float().mean().item()
y_true = query_labels.cpu().numpy()
y_pred = predictions_resnet.cpu().numpy()

print(f'\n{"="*60}')
print(f'Acurácia ResNet50 + Distância Euclidiana: {accuracy*100:.2f}%')
print(f'{"="*60}')


print("\nRelatório de Classificação:")
print(classification_report(
    y_true, 
    y_pred,
    target_names=train_dataset.classes,
    digits=3
))

# 6. Save results
exp_dir = save_experiment_results(
    experiment_name=experiment_name,
    model_name="ResNet50",
    metric_name="Distancia Euclidiana",
    normalization="Sem normalizacao L2",
    y_true=y_true,
    y_pred=y_pred,
    class_names=train_dataset.classes,
    n_train=len(train_dataset),
    n_test=len(test_dataset),
    device=device
)


Carregando ResNet50...




Extraindo features com ResNet50...

Acurácia ResNet50 + Distância Euclidiana: 46.80%

Relatório de Classificação:
              precision    recall  f1-score   support

       akiec      0.580     0.580     0.580        50
         bcc      0.492     0.620     0.549        50
         bkl      0.300     0.180     0.225        50
         mel      0.354     0.340     0.347        50
          nv      0.525     0.620     0.569        50

    accuracy                          0.468       250
   macro avg      0.450     0.468     0.454       250
weighted avg      0.450     0.468     0.454       250


Resultados salvos em: /content/drive/MyDrive/pdi/resultados/resnet50_euclidiana_5way_10shot_v4
Configuração: 5-way 10-shot
Arquivos gerados:
  - config.json
  - metrics.json
  - report.txt
  - confusion_matrix.png
  - accuracy_per_class.png
  - predictions.json

