In [38]:
import torch
import torch.nn as nn
import torchvision
import sys
import os
import numpy as np
import collections

sys.path.insert(0, os.path.abspath('src'))

from src.constants import device, BATCH_SIZE, WORKERS, PREFETCH
from src.models import ModernCNN, ImprovedCNN, EfficientCNN, AttentionCNN, evaluate_model
from src.transformations import test_transform, train_transform

## 1 - Carregamento de Modelos Pré-treinados

#### Parâmetro

    • model_path: Caminho para o ficheiro de checkpoint do modelo pré-treinado.

Carrega o resultado de um modelo pré-treinado a partir de um ficheiro de checkpoint.


In [39]:
def load_trained_model(model_path):
    checkpoint = torch.load(model_path, map_location=device)
    model_type = checkpoint['model_type']
    num_classes = checkpoint['num_classes']

    if model_type == 'AttentionCNN':
        model = AttentionCNN(num_classes=num_classes)
    elif model_type == 'EfficientCNN':
        model = EfficientCNN(num_classes=num_classes)
    elif model_type == 'ImprovedCNN':
        model = ImprovedCNN(num_classes=num_classes)
    elif model_type == 'ModernCNN':
        model = ModernCNN(num_classes=num_classes) 
    else:
        raise ValueError(f"Modelo desconhecido: '{model_type}'")

    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    return model

## 2 - Preparação de Dados

#### Estrutura e Configuração

    • Configuração do conjunto de treino

    • Configuração do conjunto de teste

In [42]:
# Load dataset with augmentations
data_path = 'data/'  # Path to the data folder

# Load training dataset with augmentation
train_dataset = torchvision.datasets.ImageFolder(
    root=data_path + 'train_images',
    transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    num_workers=WORKERS,
    pin_memory=True,
    prefetch_factor=PREFETCH,
    persistent_workers=True
)

# Load test dataset
test_dataset = torchvision.datasets.ImageFolder(
    root=data_path + 'test_images',
    transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False,
    num_workers=WORKERS,
    pin_memory=True,
    prefetch_factor=PREFETCH,
    persistent_workers=True
)




## 3 - Definição de Modelos por Configuração

#### Estrutura e Organização

1. **Organização por tipo de augmentation**:
   - Seis configurações distintas são definidas: advanced, aggressive, color, default, basic, geometric
   - Cada configuração corresponde a uma estratégia específica de data augmentation utilizada durante o treino dos modelos
   - Esta organização permite avaliar o impacto de diferentes técnicas de augmentation no desempenho dos ensembles

2. **Consistência arquitetural**:
   - Para cada configuração, são incluídas as mesmas quatro arquiteturas: AttentionCNN, EfficientCNN, ImprovedCNN, ModernCNN
   - Esta consistência permite comparações diretas entre diferentes configurações de augmentation
   - A ordem das arquiteturas é mantida consistente em todas as listas, facilitando o processamento

3. **Uso de ficheiro `.pt`**:
   - O sufixo `.pt` indica que são ficheiros de checkpoint PyTorch

4. **Constante N_MODELS**:
   - Define explicitamente o número de modelos em cada configuração (4)
   - Esta constante será utilizada posteriormente para iterações e indexação
   - A definição explícita melhora a manutenibilidade do código, evitando números mágicos


In [43]:
models_advanced = [
    "results/AttentionCNN_advanced.pt",
    "results/EfficientCNN_advanced.pt",
    "results/ImprovedCNN_advanced.pt",
    "results/ModernCNN_advanced.pt",
]
models_aggressive = [
    "results/AttentionCNN_aggressive.pt",
    "results/EfficientCNN_aggressive.pt",
    "results/ImprovedCNN_aggressive.pt",
    "results/ModernCNN_aggressive.pt",
]
models_color = [
    "results/AttentionCNN_color.pt",
    "results/EfficientCNN_color.pt",
    "results/ImprovedCNN_color.pt",
    "results/ModernCNN_color.pt",
]

models_default = [
    "results/AttentionCNN_default.pt",
    "results/EfficientCNN_default.pt",
    "results/ImprovedCNN_default.pt",
    "results/ModernCNN_default.pt",
]

models_basic = [
    "results/AttentionCNN_basic.pt",
    "results/EfficientCNN_basic.pt",
    "results/ImprovedCNN_basic.pt",
    "results/ModernCNN_basic.pt",
]

models_geometric = [
    "results/AttentionCNN_geometric.pt",
    "results/EfficientCNN_geometric.pt",
    "results/ImprovedCNN_geometric.pt",
    "results/ModernCNN_geometric.pt",
]

N_MODELS = 4

## 4 - Load models

#### Load de todos os modelos para cada augmetation.

#### Cálculo da accuracy de todos os modelo para cada augmentation

In [44]:
loaded_models_advanced = [load_trained_model(path) for path in models_advanced]
acc_advanced = [evaluate_model(model, test_loader) for model in loaded_models_advanced]

loaded_models_aggressive = [load_trained_model(path) for path in models_aggressive]
acc_aggressive = [evaluate_model(model, test_loader) for model in loaded_models_aggressive]

loaded_models_color = [load_trained_model(path) for path in models_color]
acc_color = [evaluate_model(model, test_loader) for model in loaded_models_color]

loaded_models_basic = [load_trained_model(path) for path in models_basic]
acc_basic = [evaluate_model(model, test_loader) for model in loaded_models_basic]

loaded_models_default = [load_trained_model(path) for path in models_default]
acc_default = [evaluate_model(model, test_loader) for model in loaded_models_default]

loaded_models_geometric = [load_trained_model(path) for path in models_geometric]
acc_geometric = [evaluate_model(model, test_loader) for model in loaded_models_geometric]

## 5 - Get Labels, Predictions and Logits

In [45]:
def get_labels_logits_and_preds(models):

    with torch.no_grad():
        logits = [[] for _ in range(N_MODELS)]
        labels = []

        for images, labs in test_loader:

            images = images.to(device)
            labels.extend(labs)
            
            for i in range(N_MODELS):
                logits[i].extend(models[i](images).cpu())


    return labels, logits

labels_advanced, logits_advanced = get_labels_logits_and_preds(loaded_models_advanced)
labels_aggressive, logits_aggressive = get_labels_logits_and_preds(loaded_models_aggressive)
labels_color, logits_color = get_labels_logits_and_preds(loaded_models_color)
labels_basic, logits_basic = get_labels_logits_and_preds(loaded_models_basic)
labels_default, logits_default = get_labels_logits_and_preds(loaded_models_default)
labels_geometric, logits_geometric = get_labels_logits_and_preds(loaded_models_geometric)

In [46]:
preds_advanced = [[] for _ in range(len(labels_advanced))]
preds_aggressive = [[] for _ in range(len(labels_aggressive))]
preds_color = [[] for _ in range(len(labels_color))]
preds_basic = [[] for _ in range(len(labels_basic))]
preds_default = [[] for _ in range(len(labels_default))]
preds_geometric = [[] for _ in range(len(labels_geometric))]

for index in range(len(labels_advanced)):
    preds_advanced[index] = [np.argmax(logits_advanced[m][index].cpu().numpy()) for m in range(N_MODELS)]
    preds_aggressive[index] = [np.argmax(logits_aggressive[m][index].cpu().numpy()) for m in range(N_MODELS)]
    preds_color[index] = [np.argmax(logits_color[m][index].cpu().numpy()) for m in range(N_MODELS)]
    preds_basic[index] = [np.argmax(logits_basic[m][index].cpu().numpy()) for m in range(N_MODELS)]
    preds_default[index] = [np.argmax(logits_default[m][index].cpu().numpy()) for m in range(N_MODELS)]
    preds_geometric[index] = [np.argmax(logits_geometric[m][index].cpu().numpy()) for m in range(N_MODELS)]

In [47]:
def get_class_from_sum_of_logits(logits):

    sum_logits = []

    for i in range(len(logits[0])):

        log = logits[0][i]
        for m in range(1, N_MODELS):
            log = np.add(log, logits[m][i])
        sum_logits.append(np.argmax(log))
    return(sum_logits)
    
class_logits_advanced = get_class_from_sum_of_logits(logits_advanced)
class_logits_aggressive = get_class_from_sum_of_logits(logits_aggressive)
class_logits_color = get_class_from_sum_of_logits(logits_color)
class_logits_basic = get_class_from_sum_of_logits(logits_basic)
class_logits_default = get_class_from_sum_of_logits(logits_default)
class_logits_geometric = get_class_from_sum_of_logits(logits_geometric)

  log = np.add(log, logits[m][i])


## 6 - Voto majoritário

#### Usa as previsões de cada modelo para avaliar os resultados da aplicação do conjunto de modelos

----

#### Voto Majoritário:

    - Se mais de metade dos modelos acertar a previsão, então é considerado um acerto para o resultado final;
    
    - Se mais de metade dos modelos errar a previsão, então é considerado um erro para o resultado final;

    - Em caso de empate, então é considerado um erro para o resultado final.


In [48]:
def get_stats(labels, class_preds, class_logits):

    all_correct = 0
    all_incorrect = 0
    maj_vote = 0
    maj_wrong = 0
    tie = 0
    count = 0

    for k in range(len(labels)):

        counter = collections.Counter(class_preds[k])
        if len(counter) == 1:
            if counter.most_common(1)[0][0] == labels[k]:
                all_correct += 1
            else:
                all_incorrect += 1
        else:
            aux = counter.most_common(2)
            if aux[0][1] > aux[1][1] and aux[0][0] == labels[k]:
                maj_vote += 1
            if aux[0][1] > aux[1][1] and aux[0][0] != labels[k]:
                maj_wrong += 1
            elif aux[0][1] == aux[1][1]:
                tie += 1

        count += 1 
        
    return [count, all_correct, all_incorrect, maj_vote, tie, maj_wrong]

## 7 - Resultados

| Configuração | Acurácia | Melhor Aspecto |
|--------------|----------|----------------|
| Advanced     | 97.03%   | Maior consenso (92.3% concordância total) |
| Color        | 97.05%   | Ligeiramente superior ao Advanced |
| Basic        | 96.63%   | Boa performance com *augmentation* simples |
| Geometric    | 95.46%   | Beneficia de transformações espaciais |
| Default      | 94.89%   | *Baseline* sólido |
| Aggressive   | 92.26%   | Maior diversidade, mais empates |


In [None]:
def print_stats(labels, class_preds, class_logits, type="Default"):
    res = get_stats(labels, class_preds, class_logits)
    print("="*80)
    print(f"{type} augmentation Ensemble Results")
    print('total: ', res[0])
    print('All correct: ', res[1])
    print('All incorrect: ', res[2])
    print('Majority correct: ', res[3])
    print('Tie Vote: ', res[4])
    print('Majority Wrong: ', res[5])
    print('Percentage right (all correct + majority correct): ', (res[1]+res[3])/res[0])
    print("="*80)

print_stats(labels_advanced, preds_advanced, class_logits_advanced, type="Advanced")
print_stats(labels_aggressive, preds_aggressive, class_logits_aggressive, type="Aggressive")
print_stats(labels_color, preds_color, class_logits_color, type="Color")
print_stats(labels_basic, preds_basic, class_logits_basic, type="Basic")
print_stats(labels_default, preds_default, class_logits_default, type="Default")
print_stats(labels_geometric, preds_geometric, class_logits_geometric, type="Geometric")

Advanced Argumentation Ensemble Results
total:  12630
All correct:  11664
All incorrect:  79
Majority correct:  591
Tie Vote:  170
Majority Wrong:  126
Percentage right (all correct + majority correct):  0.9703087885985748
Aggressive Argumentation Ensemble Results
total:  12630
All correct:  10391
All incorrect:  188
Majority correct:  1262
Tie Vote:  394
Majority Wrong:  395
Percentage right (all correct + majority correct):  0.9226444972288202
Color Argumentation Ensemble Results
total:  12630
All correct:  11664
All incorrect:  72
Majority correct:  593
Tie Vote:  164
Majority Wrong:  137
Percentage right (all correct + majority correct):  0.9704671417260491
Basic Argumentation Ensemble Results
total:  12630
All correct:  11481
All incorrect:  84
Majority correct:  724
Tie Vote:  213
Majority Wrong:  128
Percentage right (all correct + majority correct):  0.9663499604117182
Default Argumentation Ensemble Results
total:  12630
All correct:  10956
All incorrect:  80
Majority correct: 