In [1]:
"""
Pipeline completo para executar 3 modelos treinados sequencialmente
com visualizações, métricas e salvamento de resultados em pasta 'results'
"""

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import pandas as pd
import cv2
import time
import psutil
from datetime import datetime, timedelta
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    classification_report, confusion_matrix, precision_recall_fscore_support
)
import warnings
warnings.filterwarnings('ignore')

In [2]:
# ============================================================================
# CONFIGURATION
# ============================================================================

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMG_SIZE = 224
BATCH_SIZE = 32

# PATHS - adapt as needed. Judging by the names, checkpoints trained on EXPW
# are evaluated here on the pre-processed RAF-DB test split.
MODELS_PATH = "../results/augmented_online/EXPW/models/"
RESULTS_PATH = "../results/cross_data/expw_for_raf_processed"
DATASET_PATH = "../data/processed/RAF-DB"  # base dataset for this experiment

# Output tree: the results root plus one sub-folder per artifact type.
os.makedirs(RESULTS_PATH, exist_ok=True)
for _subdir in ("metrics", "plots", "models"):
    os.makedirs(os.path.join(RESULTS_PATH, _subdir), exist_ok=True)
del _subdir

# Class name -> class index. The index order presumably has to match the label
# order used when the checkpoints were trained (TODO confirm against training code).
EMOTION_LABELS = {
    name: index
    for index, name in enumerate(
        ['Raiva', 'Nojo', 'Medo', 'Felicidade', 'Neutro', 'Tristeza', 'Surpresa']
    )
}

# The 3 checkpoints to evaluate, one per architecture.
MODELS_TO_LOAD = [
    f"{arch}_best.pth" for arch in ('resnet50', 'efficientnet_b0', 'vit_b_16')
]


In [3]:
# ============================================================================
# CLASSES E FUN√á√ïES
# ============================================================================

class EmotionDataset(Dataset):
    """Emotion-classification dataset over PRE-PROCESSED single-channel images.

    Args:
        images: indexable collection of grayscale images, each ``[H, W]``.
            ``load_dataset_from_folder`` produces float32 arrays in [0, 1];
            uint8 arrays in [0, 255] are also accepted.
        labels: indexable collection of integer class indices aligned with ``images``.
        transform: optional callable applied to each raw image. When it is
            falsy, the image is converted directly to a float32 ``[1, H, W]``
            tensor in [0, 1] (uint8 input is divided by 255, float input is
            kept as-is).
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
        else:
            image = self._to_tensor(image)

        return image, torch.tensor(label, dtype=torch.long)

    @staticmethod
    def _to_tensor(image):
        """Convert one image array to a float32 channel-first tensor in [0, 1].

        Building a PIL 'L' image from float data would reinterpret the float
        bytes as 8-bit pixels, so the conversion is done numerically instead.
        """
        array = np.asarray(image)
        if array.dtype == np.uint8:
            array = array.astype(np.float32) / 255.0
        else:
            array = array.astype(np.float32)
        tensor = torch.from_numpy(np.ascontiguousarray(array))
        if tensor.ndim == 2:
            # [H, W] -> [1, H, W]: a single grayscale channel.
            tensor = tensor.unsqueeze(0)
        elif tensor.ndim == 3:
            # [H, W, C] -> [C, H, W], the same layout ToTensor would produce.
            tensor = tensor.permute(2, 0, 1).contiguous()
        return tensor

def _to_single_channel_conv(conv):
    """Return a 1-input-channel copy of ``conv``, keeping its geometry and bias.

    The new kernel is the mean of the original kernel over its input-channel
    axis, so pretrained RGB filters still respond sensibly to grayscale input.
    """
    gray_conv = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        gray_conv.weight.copy_(conv.weight.mean(dim=1, keepdim=True))
        if conv.bias is not None:
            gray_conv.bias.copy_(conv.bias)
    return gray_conv


def create_model(model_name, pretrained=True, num_classes=7):
    """
    Build a backbone adapted to GRAYSCALE (1-channel) input.

    The first convolution is replaced by a single-channel copy (kernel averaged
    over RGB, bias preserved). The last classification layer is replaced by a
    plain ``nn.Linear(in_features, num_classes)``, which keeps the module layout
    compatible with the saved checkpoints.

    Args:
        model_name: one of 'resnet50', 'efficientnet_b0', 'vit_b_16'.
        pretrained: start from torchvision ImageNet weights (the original
            behavior). Passing False skips fetching/loading them, which is
            useful when a checkpoint overwrites every parameter anyway.
        num_classes: output size of the new classification layer.

    Returns:
        The adapted ``nn.Module``.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    supported = ('resnet50', 'efficientnet_b0', 'vit_b_16')
    if model_name not in supported:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of {supported}"
        )

    if model_name == 'resnet50':
        weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        model = models.resnet50(weights=weights)
        model.conv1 = _to_single_channel_conv(model.conv1)
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    elif model_name == 'efficientnet_b0':
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.efficientnet_b0(weights=weights)
        model.features[0][0] = _to_single_channel_conv(model.features[0][0])
        # Replace only the final Linear of the classifier (the one before it is left untouched).
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    else:  # 'vit_b_16'
        weights = models.ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.vision_transformer.vit_b_16(weights=weights)
        # The patch-embedding conv has a bias; the helper carries it over.
        model.conv_proj = _to_single_channel_conv(model.conv_proj)
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)

    return model


def load_model_checkpoint(model_name, checkpoint_path):
    """Build ``model_name`` and load its trained weights from ``checkpoint_path``.

    Two checkpoint formats are accepted:
    - a dict containing ``'model_state_dict'`` (and optionally ``'metrics'``);
    - a bare state dict.

    Returns:
        ``(model, metadata)``: the model is moved to ``DEVICE`` and set to eval
        mode; ``metadata`` is the checkpoint's ``'metrics'`` entry, or ``{}``.
        Returns ``None`` if ``checkpoint_path`` does not exist.
    """
    print(f"\nüîÑ Carregando modelo: {model_name}")

    # Check the file before building the model: construction may fetch
    # pretrained weights, which is wasted work if there is nothing to load.
    if not os.path.exists(checkpoint_path):
        print(f"‚ùå Arquivo n√£o encontrado: {checkpoint_path}")
        return None

    model = create_model(model_name)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        metadata = checkpoint.get('metrics', {})
        print(f"‚úì Checkpoint completo carregado")
    else:
        model.load_state_dict(checkpoint)
        metadata = {}
        print(f"‚úì State dict carregado")

    model = model.to(DEVICE)
    model.eval()

    return model, metadata


def _list_image_files(folder, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
    """Return the visible image files directly inside ``folder``, sorted by name.

    Extensions match case-insensitively ('.JPG' counts as '.jpg'), and each
    file is listed once even on case-insensitive file systems.
    """
    if not os.path.isdir(folder):
        return []
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if not name.startswith('.')
        and os.path.splitext(name)[1].lower() in extensions
        and os.path.isfile(os.path.join(folder, name))
    )


def _read_grayscale(img_file):
    """Read ``img_file`` as float32 ``[IMG_SIZE, IMG_SIZE]`` in [0, 1]; None if unreadable."""
    img = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return None
    if img.shape != (IMG_SIZE, IMG_SIZE):
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
    return img.astype(np.float32) / 255.0


def load_dataset_from_folder(directory_path, subset='test'):
    """
    Load a PRE-PROCESSED grayscale dataset laid out as one folder per emotion.

    Expected layout (auto-detected):
    - flat:  dataset/<Emotion>/image.jpg
    - split: dataset/train/<Emotion>/image.jpg or dataset/test/<Emotion>/image.jpg

    Only folders named after a key of ``EMOTION_LABELS`` are loaded. Images are
    read in grayscale, resized to ``IMG_SIZE`` x ``IMG_SIZE`` when needed and
    scaled to float32 in [0, 1]. Unreadable files are skipped and reported.

    Args:
        directory_path: dataset root.
        subset: 'test' or 'train'; only used when a train/test split is detected.

    Returns:
        ``(images, labels)`` as numpy arrays ``[N, IMG_SIZE, IMG_SIZE]`` and
        ``[N]``, or ``(None, None)`` if the tree is missing or nothing loaded.
    """
    print(f"\nüìÇ Carregando dataset PR√â-PROCESSADO: {directory_path}")

    images = []
    labels = []

    if not os.path.exists(directory_path):
        print(f"‚ùå Diret√≥rio n√£o encontrado: {directory_path}")
        return None, None

    # ===== AUTOMATIC LAYOUT DETECTION =====
    try:
        subdirs = [d for d in os.listdir(directory_path)
                   if os.path.isdir(os.path.join(directory_path, d))]
    except Exception as e:
        print(f"‚ùå Erro ao listar diret√≥rio: {e}")
        return None, None

    print(f"  Subdiret√≥rios encontrados: {subdirs}")

    # A 'train' or 'test' folder means the dataset is split; pick `subset`.
    has_train = 'train' in subdirs
    has_test = 'test' in subdirs
    has_split = has_train or has_test

    if has_split:
        print(f"‚úì Estrutura com split detectada (train={has_train}, test={has_test})")
        base_path = os.path.join(directory_path, subset)
        print(f"  ‚Üí Usando subset: {subset}")
        print(f"  ‚Üí Caminho base: {base_path}")

        if not os.path.exists(base_path):
            print(f"‚ùå Subset '{subset}' n√£o encontrado em {directory_path}")
            print(f"  Caminho esperado: {base_path}")
            return None, None

        emotion_subdirs = [d for d in os.listdir(base_path)
                          if os.path.isdir(os.path.join(base_path, d))]
        print(f"  Emo√ß√µes em {subset}: {emotion_subdirs}")
    else:
        print(f"‚úì Estrutura simples detectada (emo√ß√µes diretamente no diret√≥rio)")
        base_path = directory_path
        emotion_subdirs = subdirs
        print(f"  Emo√ß√µes encontradas: {emotion_subdirs}")

    # ===== IMAGE LOADING =====
    print(f"\n  Iniciando carregamento de dados PR√â-PROCESSADOS...")

    for emotion, label in EMOTION_LABELS.items():
        emotion_path = os.path.join(base_path, emotion)

        if not os.path.exists(emotion_path):
            print(f"  ‚ö†  '{emotion}' n√£o encontrado em {emotion_path}")
            continue

        image_files = _list_image_files(emotion_path)
        if len(image_files) == 0:
            print(f"  ‚ö†  Nenhuma imagem em {emotion_path}")
            continue

        count = 0
        failed = 0
        for img_file in image_files:
            try:
                img = _read_grayscale(img_file)
            except Exception:
                # Best-effort load: one corrupt file must not abort the whole
                # dataset, but it is counted and reported below.
                img = None
            if img is None:
                failed += 1
                continue
            images.append(img)
            labels.append(label)
            count += 1

        print(f"  ‚úì {emotion:15s}: {count:4d} imagens carregadas")
        if failed:
            print(f"  [!] {emotion}: {failed} arquivo(s) ignorado(s) por erro de leitura")

    if len(images) == 0:
        print("\n‚ùå Nenhuma imagem carregada!")
        return None, None

    print(f"\n‚úì Total: {len(images)} imagens carregadas com sucesso")
    print(f"  Formato: [N, {IMG_SIZE}, {IMG_SIZE}] - Escala de cinza (1 canal)")
    return np.array(images), np.array(labels)

def get_transforms():
    """Build the evaluation transform for PRE-PROCESSED grayscale images.

    No resize step: the loader already delivers 224x224 arrays. Each array goes
    through PIL and back to a tensor, then is normalized with a single-channel
    mean/std of 0.485/0.229 (the commonly used ImageNet red-channel statistics).
    """
    pipeline_steps = [
        transforms.ToPILImage(),
        transforms.ToTensor(),  # inputs are already 224x224, so no Resize here
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ]
    return transforms.Compose(pipeline_steps)

def evaluate_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return classification metrics.

    Args:
        model: classifier whose output is ``[batch, n_classes]`` scores.
        test_loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        dict with:
        - 'accuracy', 'f1_macro', 'f1_weighted', 'precision' (macro), 'recall' (macro);
        - 'conf_matrix' and 'class_report', always built over EVERY class of
          ``EMOTION_LABELS`` in index order, even when a class is absent from
          this split, so their shape/keys line up with the emotion names;
        - 'y_true', 'y_pred': numpy arrays of class indices;
        - 'avg_inference_time': mean seconds PER IMAGE (device transfer +
          forward). This is what the "ms" and "img/s" reports derived from
          it assume;
        - 'avg_batch_inference_time': mean seconds per batch.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()

    # Fix the class set explicitly: without `labels=`, sklearn only uses classes
    # present in y_true/y_pred, shrinking the matrix and breaking `target_names`.
    ordered_labels = sorted(EMOTION_LABELS.items(), key=lambda item: item[1])
    emotion_names = [name for name, _ in ordered_labels]
    class_ids = [class_id for _, class_id in ordered_labels]

    all_preds = []
    all_targets = []
    batch_times = []
    n_images = 0

    with torch.no_grad():
        for data, target in test_loader:
            start = time.perf_counter()
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            # CUDA kernels run asynchronously: wait for them so the timer covers
            # the actual forward pass rather than just its launch.
            if output.is_cuda:
                torch.cuda.synchronize()
            batch_times.append(time.perf_counter() - start)
            n_images += target.size(0)

            _, predicted = output.max(1)
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    if n_images == 0:
        raise ValueError("test_loader produced no samples; nothing to evaluate")

    accuracy = accuracy_score(all_targets, all_preds)
    f1_macro = f1_score(all_targets, all_preds, average='macro', zero_division=0)
    f1_weighted = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)

    conf_matrix = confusion_matrix(all_targets, all_preds, labels=class_ids)

    class_report = classification_report(
        all_targets, all_preds,
        labels=class_ids,
        target_names=emotion_names,
        output_dict=True,
        zero_division=0
    )

    return {
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'precision': precision,
        'recall': recall,
        'conf_matrix': conf_matrix,
        'class_report': class_report,
        'y_true': np.array(all_targets),
        'y_pred': np.array(all_preds),
        'avg_inference_time': sum(batch_times) / n_images,
        'avg_batch_inference_time': float(np.mean(batch_times)),
    }


In [4]:
def create_comprehensive_plots(model_name, metrics, experiment_id):
    """
    Build a 3x4 dashboard (12 panels) for one evaluated model and save it as PNG.

    Panels:
    - raw and row-normalized confusion matrices;
    - per-class F1, and per-class precision/recall/F1;
    - test class distribution and a support-vs-F1 scatter;
    - off-diagonal error heatmap and macro vs weighted averages;
    - inference timing and a metrics summary;
    - run info and process/GPU resource info.

    Args:
        model_name: label used in titles and in the output file name.
        metrics: dict as returned by ``evaluate_model``. Reads 'conf_matrix',
            'class_report', 'y_true', 'y_pred', 'accuracy', 'f1_macro',
            'f1_weighted', 'precision', 'recall' and 'avg_inference_time'.
        experiment_id: suffix for the output file name.

    Side effects:
        Writes ``<RESULTS_PATH>/plots/<model_name>_comprehensive_<experiment_id>.png``.
    """
    print(f"\nüìä Gerando visualiza√ß√µes para {model_name}...")

    fig = plt.figure(figsize=(24, 18))
    emotion_names = list(EMOTION_LABELS.keys())
    conf_matrix = metrics['conf_matrix']
    class_report = metrics['class_report']
    avg_time = metrics['avg_inference_time']
    # A single zero-guard shared by the timing bar chart (panel 9) and the text
    # summary (panel 10); 0 means "no timing available".
    throughput = 1.0 / avg_time if avg_time > 0 else 0

    # Row totals = true samples per class. Rows with no samples normalize to 0
    # instead of NaN.
    row_totals = conf_matrix.sum(axis=1)[:, np.newaxis]

    def _row_normalize(matrix):
        return np.divide(matrix.astype('float'), row_totals,
                         out=np.zeros(matrix.shape, dtype=float),
                         where=row_totals != 0)

    # ===== 1. RAW CONFUSION MATRIX =====
    ax1 = fig.add_subplot(3, 4, 1)
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=emotion_names, yticklabels=emotion_names, ax=ax1,
                cbar_kws={'label': 'Amostras'})
    ax1.set_title(f'{model_name}: Matriz de Confus√£o (Raw)', fontsize=12, fontweight='bold')
    ax1.set_ylabel('Classe Verdadeira')
    ax1.set_xlabel('Classe Predita')
    plt.setp(ax1.get_xticklabels(), rotation=45)

    # ===== 2. ROW-NORMALIZED MATRIX (diagonal = per-class recall) =====
    ax2 = fig.add_subplot(3, 4, 2)
    conf_norm = _row_normalize(conf_matrix)
    sns.heatmap(conf_norm, annot=True, fmt='.3f', cmap='Greens',
                xticklabels=emotion_names, yticklabels=emotion_names, ax=ax2,
                cbar_kws={'label': 'Propor√ß√£o'})
    ax2.set_title(f'{model_name}: Matriz Normalizada (Recall)', fontsize=12, fontweight='bold')
    ax2.set_ylabel('Classe Verdadeira')
    ax2.set_xlabel('Classe Predita')
    plt.setp(ax2.get_xticklabels(), rotation=45)

    # ===== 3. F1-SCORE PER CLASS =====
    ax3 = fig.add_subplot(3, 4, 3)
    f1_scores = [class_report[emotion]['f1-score'] for emotion in emotion_names]
    colors = plt.cm.viridis(np.linspace(0, 1, len(emotion_names)))
    bars = ax3.bar(emotion_names, f1_scores, color=colors, alpha=0.8, edgecolor='black')
    ax3.set_title(f'{model_name}: F1-Score por Emo√ß√£o', fontsize=12, fontweight='bold')
    ax3.set_ylabel('F1-Score')
    ax3.set_ylim(0, 1)
    plt.setp(ax3.get_xticklabels(), rotation=45)
    for bar, score in zip(bars, f1_scores):
        ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                f'{score:.3f}', ha='center', va='bottom', fontsize=9)

    # ===== 4. PRECISION, RECALL, F1 PER CLASS =====
    ax4 = fig.add_subplot(3, 4, 4)
    precision_scores = [class_report[emotion]['precision'] for emotion in emotion_names]
    recall_scores = [class_report[emotion]['recall'] for emotion in emotion_names]
    x = np.arange(len(emotion_names))
    width = 0.25
    ax4.bar(x - width, precision_scores, width, label='Precision', alpha=0.8, color='lightcoral')
    ax4.bar(x, recall_scores, width, label='Recall', alpha=0.8, color='lightblue')
    ax4.bar(x + width, f1_scores, width, label='F1-Score', alpha=0.8, color='lightgreen')
    ax4.set_title(f'{model_name}: M√©tricas por Classe', fontsize=12, fontweight='bold')
    ax4.set_ylabel('Score')
    ax4.set_xticks(x)
    ax4.set_xticklabels(emotion_names, rotation=45)
    ax4.legend()
    ax4.set_ylim(0, 1)

    # ===== 5. CLASS DISTRIBUTION (TEST) =====
    ax5 = fig.add_subplot(3, 4, 5)
    unique, counts = np.unique(metrics['y_true'], return_counts=True)
    class_dist = dict(zip(unique, counts))
    test_counts = [class_dist.get(i, 0) for i in range(len(emotion_names))]
    colors_dist = plt.cm.Set3(np.linspace(0, 1, len(emotion_names)))
    bars = ax5.bar(emotion_names, test_counts, color=colors_dist, alpha=0.8, edgecolor='black')
    ax5.set_title(f'{model_name}: Distribui√ß√£o de Classes (Teste)', fontsize=12, fontweight='bold')
    ax5.set_ylabel('N√∫mero de Amostras')
    plt.setp(ax5.get_xticklabels(), rotation=45)
    mean_samples = np.mean(test_counts)
    ax5.axhline(y=mean_samples, color='red', linestyle='--', alpha=0.7, label=f'M√©dia: {mean_samples:.0f}')
    ax5.legend()
    for bar, count in zip(bars, test_counts):
        ax5.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(test_counts)*0.02,
                f'{count}', ha='center', va='bottom', fontweight='bold')

    # ===== 6. SUPPORT vs F1 CORRELATION =====
    ax6 = fig.add_subplot(3, 4, 6)
    support_counts = [class_report[emotion]['support'] for emotion in emotion_names]
    ax6.scatter(support_counts, f1_scores, c=support_counts, cmap='viridis', s=100, alpha=0.7, edgecolors='black')
    for i, emotion in enumerate(emotion_names):
        ax6.annotate(emotion, (support_counts[i], f1_scores[i]), xytext=(5, 5),
                    textcoords='offset points', fontsize=8)
    # A linear fit is undefined when every class has the same support.
    if len(set(support_counts)) > 1:
        trend = np.poly1d(np.polyfit(support_counts, f1_scores, 1))
        ax6.plot(support_counts, trend(support_counts), "r--", alpha=0.8)
    correlation = np.corrcoef(support_counts, f1_scores)[0, 1]
    ax6.set_title(f'{model_name}: Amostras vs Performance', fontsize=12, fontweight='bold')
    ax6.set_xlabel('Suporte (Test)')
    ax6.set_ylabel('F1-Score')
    ax6.grid(True, alpha=0.3)
    ax6.text(0.05, 0.95, f'Correla√ß√£o: {correlation:.3f}', transform=ax6.transAxes,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), verticalalignment='top')

    # ===== 7. ERROR HEATMAP (off-diagonal only, normalized by true-class size) =====
    ax7 = fig.add_subplot(3, 4, 7)
    error_matrix = conf_matrix.copy()
    np.fill_diagonal(error_matrix, 0)
    error_norm = _row_normalize(error_matrix)
    sns.heatmap(error_norm, annot=True, fmt='.3f', cmap='Reds',
                xticklabels=emotion_names, yticklabels=emotion_names, ax=ax7,
                cbar_kws={'label': 'Taxa de Erro'})
    ax7.set_title(f'{model_name}: Heatmap de Erros', fontsize=12, fontweight='bold')
    ax7.set_ylabel('Classe Verdadeira')
    ax7.set_xlabel('Classe Predita (Erro)')
    plt.setp(ax7.get_xticklabels(), rotation=45)

    # ===== 8. MACRO vs WEIGHTED =====
    ax8 = fig.add_subplot(3, 4, 8)
    precision_weighted, recall_weighted, _, _ = precision_recall_fscore_support(
        metrics['y_true'], metrics['y_pred'], average='weighted', zero_division=0
    )
    metrics_comparison = {
        'Precision': [metrics['precision'], precision_weighted],
        'Recall': [metrics['recall'], recall_weighted],
        'F1-Score': [metrics['f1_macro'], metrics['f1_weighted']]
    }
    x = np.arange(len(metrics_comparison))
    width = 0.35
    macro_vals = [vals[0] for vals in metrics_comparison.values()]
    weighted_vals = [vals[1] for vals in metrics_comparison.values()]
    macro_bars = ax8.bar(x - width/2, macro_vals, width, label='Macro (Desbalanceado)', alpha=0.8, color='lightcoral')
    weighted_bars = ax8.bar(x + width/2, weighted_vals, width, label='Weighted (Balanceado)', alpha=0.8, color='lightblue')
    ax8.set_title(f'{model_name}: Macro vs Weighted', fontsize=12, fontweight='bold')
    ax8.set_ylabel('Score')
    ax8.set_xticks(x)
    ax8.set_xticklabels(metrics_comparison.keys())
    ax8.legend()
    ax8.set_ylim(0, 1)
    # Annotate every bar of both series explicitly.
    for bar in [*macro_bars, *weighted_bars]:
        height = bar.get_height()
        ax8.text(bar.get_x() + bar.get_width()/2, height + 0.01,
                f'{height:.3f}', ha='center', va='bottom', fontsize=9)

    # ===== 9. INFERENCE TIME =====
    ax9 = fig.add_subplot(3, 4, 9)
    time_data = {
        'Tempo M√©dio (ms)': avg_time * 1000,
        'Throughput (img/s)': throughput
    }
    colors_time = ['#1f77b4', '#ff7f0e']
    bars = ax9.bar(range(len(time_data)), list(time_data.values()), color=colors_time, alpha=0.8, edgecolor='black')
    ax9.set_title(f'{model_name}: Performance de Infer√™ncia', fontsize=12, fontweight='bold')
    ax9.set_xticks(range(len(time_data)))
    ax9.set_xticklabels(time_data.keys(), rotation=45)
    ax9.set_ylabel('Valor')
    for bar, val in zip(bars, time_data.values()):
        ax9.text(bar.get_x() + bar.get_width()/2, bar.get_height() + bar.get_height()*0.02,
                f'{val:.2f}', ha='center', va='bottom', fontweight='bold')

    # ===== 10. METRICS SUMMARY =====
    ax10 = fig.add_subplot(3, 4, 10)
    ax10.axis('off')
    summary_text = f"""
RESUMO - {model_name.upper()}

ACUR√ÅCIA E F1:
‚Ä¢ Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.2f}%)
‚Ä¢ F1-Macro: {metrics['f1_macro']:.4f}
‚Ä¢ F1-Weighted: {metrics['f1_weighted']:.4f}

PRECIS√ÉO E RECALL:
‚Ä¢ Precision: {metrics['precision']:.4f}
‚Ä¢ Recall: {metrics['recall']:.4f}

INFER√äNCIA:
‚Ä¢ Tempo M√©dio: {avg_time*1000:.2f} ms
‚Ä¢ Throughput: {throughput:.1f} img/s
    """
    ax10.text(0.05, 0.95, summary_text, fontsize=11, verticalalignment='top',
             transform=ax10.transAxes, family='monospace',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

    # ===== 11. RUN INFO (placeholder; training history is not available here) =====
    ax11 = fig.add_subplot(3, 4, 11)
    ax11.text(0.5, 0.5, f'Modelo: {model_name}\nData: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n\nDataset: Teste',
             ha='center', va='center', fontsize=12, fontweight='bold',
             transform=ax11.transAxes,
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
    ax11.axis('off')

    # ===== 12. SYSTEM INFORMATION =====
    ax12 = fig.add_subplot(3, 4, 12)
    ax12.axis('off')
    process = psutil.Process()
    memory_info = process.memory_info()
    system_text = f"""
RECURSOS COMPUTACIONAIS

Dispositivo: {DEVICE}
CPU: {os.cpu_count()} cores

Mem√≥ria:
‚Ä¢ Uso Atual: {memory_info.rss / 1024**3:.2f} GB
‚Ä¢ RSS: {memory_info.rss / 1024**2:.1f} MB

CUDA (se dispon√≠vel):
"""
    if torch.cuda.is_available():
        system_text += f"""‚Ä¢ GPU: {torch.cuda.get_device_name(0)}
‚Ä¢ Mem√≥ria GPU: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB
‚Ä¢ Alocada: {torch.cuda.memory_allocated() / 1024**2:.1f} MB
"""
    ax12.text(0.05, 0.95, system_text, fontsize=10, verticalalignment='top',
             transform=ax12.transAxes, family='monospace',
             bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))

    fig.tight_layout()
    plot_path = os.path.join(RESULTS_PATH, 'plots', f'{model_name}_comprehensive_{experiment_id}.png')
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"‚úì Plot salvo: {plot_path}")
    # Close this specific figure so repeated calls (one per model) don't accumulate figures.
    plt.close(fig)


def run_single_model_evaluation(model_file, X_test, y_test):
    """Evaluate one checkpoint end to end on the provided test arrays.

    Steps:
    - load ``MODELS_PATH/model_file``;
    - wrap ``X_test``/``y_test`` in an ``EmotionDataset`` with the evaluation
      transform and compute metrics;
    - print a console summary;
    - save the dashboard figure and a one-row metrics CSV under ``RESULTS_PATH``.

    Returns:
        The metrics dict from ``evaluate_model`` with an extra
        'evaluation_time' (seconds), or None if the checkpoint failed to load.
    """
    model_name = model_file.replace('_best.pth', '')
    banner = "=" * 80

    print("\n" + banner)
    print(f"EXECUTANDO: {model_name.upper()}")
    print(banner)

    # 1. Load trained weights; a missing checkpoint comes back as None.
    loaded = load_model_checkpoint(model_name, os.path.join(MODELS_PATH, model_file))
    if loaded is None:
        print(f"‚ùå Falha ao carregar modelo")
        return None
    model, _metadata = loaded

    # 2. Deterministic evaluation loader (no shuffling).
    test_loader = DataLoader(
        EmotionDataset(X_test, y_test, transform=get_transforms()),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
    )

    # 3. Evaluate and record the wall-clock time of the full pass.
    print(f"\n‚è±Ô∏è  Avaliando modelo...")
    t_begin = time.time()
    metrics = evaluate_model(model, test_loader)
    evaluation_time = time.time() - t_begin
    metrics['evaluation_time'] = evaluation_time

    # 4. Console summary.
    report_lines = [
        f"\nüìä RESULTADOS PARA {model_name.upper()}:",
        f"  Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.2f}%)",
        f"  F1-Macro: {metrics['f1_macro']:.4f}",
        f"  F1-Weighted: {metrics['f1_weighted']:.4f}",
        f"  Precision: {metrics['precision']:.4f}",
        f"  Recall: {metrics['recall']:.4f}",
        f"  Tempo de Infer√™ncia M√©dio: {metrics['avg_inference_time']*1000:.2f} ms",
        f"  Tempo Total de Avalia√ß√£o: {evaluation_time:.2f} s",
    ]
    for line in report_lines:
        print(line)

    # 5. One timestamp shared by this run's figure and CSV.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    create_comprehensive_plots(model_name, metrics, timestamp)

    # 6. Single-row CSV with the headline metrics.
    summary_row = {
        'model_name': model_name,
        'accuracy': metrics['accuracy'],
        'f1_macro': metrics['f1_macro'],
        'f1_weighted': metrics['f1_weighted'],
        'precision': metrics['precision'],
        'recall': metrics['recall'],
        'avg_inference_time_ms': metrics['avg_inference_time'] * 1000,
        'evaluation_time_s': evaluation_time,
        'timestamp': timestamp,
    }
    metrics_df = pd.DataFrame([summary_row])

    csv_path = os.path.join(RESULTS_PATH, 'metrics', f'{model_name}_metrics_{timestamp}.csv')
    metrics_df.to_csv(csv_path, index=False)
    print(f"\n‚úì M√©tricas salvas: {csv_path}")

    return metrics

In [5]:
def main():
    """Main pipeline.

    Steps:
    - load the test split once from ``DATASET_PATH``;
    - evaluate every checkpoint in ``MODELS_TO_LOAD`` sequentially;
    - if at least one succeeded, save a comparison CSV to
      ``RESULTS_PATH/metrics`` and a comparison chart to ``RESULTS_PATH/plots``.
      Both files share one timestamp.
    """
    print("\n" + "="*80)
    print("PIPELINE DE AVALIA√á√ÉO DE 3 MODELOS TREINADOS")
    print("="*80)
    print(f"Device: {DEVICE}")
    print(f"Dataset Path: {DATASET_PATH}")
    print(f"Models Path: {MODELS_PATH}")
    print(f"Results Path: {RESULTS_PATH}")

    # 1. Load the test set ONCE and reuse it for every model.
    print("\n" + "="*80)
    print("CARREGANDO DATASET DE TESTE")
    print("="*80)

    X_test, y_test = load_dataset_from_folder(DATASET_PATH, subset="test")

    if X_test is None:
        print("‚ùå Falha ao carregar dataset")
        return

    # 2. Evaluate each model sequentially; failed loads are skipped.
    all_results = []

    for model_file in MODELS_TO_LOAD:
        metrics = run_single_model_evaluation(model_file, X_test, y_test)
        if metrics is not None:
            all_results.append({
                'model': model_file.replace('_best.pth', ''),
                'accuracy': metrics['accuracy'],
                'f1_macro': metrics['f1_macro'],
                'f1_weighted': metrics['f1_weighted'],
                'precision': metrics['precision'],
                'recall': metrics['recall'],
                'inference_time_ms': metrics['avg_inference_time'] * 1000
            })

    # 3. Final comparison
    print("\n" + "="*80)
    print("COMPARA√á√ÉO FINAL DOS 3 MODELOS")
    print("="*80)

    if all_results:
        df_comparison = pd.DataFrame(all_results)
        print("\n" + df_comparison.to_string(index=False))

        # Take the timestamp once so the CSV and the PNG of this run always
        # share the same suffix (two now() calls can straddle a second boundary).
        run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        comparison_path = os.path.join(RESULTS_PATH, 'metrics', f'comparison_{run_stamp}.csv')
        df_comparison.to_csv(comparison_path, index=False)
        print(f"\n‚úì Compara√ß√£o salva: {comparison_path}")

        # Comparison chart: one panel per metric, one bar per model.
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        metrics_to_plot = ['accuracy', 'f1_macro', 'f1_weighted', 'precision', 'recall', 'inference_time_ms']
        colors = plt.cm.Set3(np.linspace(0, 1, len(df_comparison)))

        for ax, metric in zip(axes.flatten(), metrics_to_plot):
            bars = ax.bar(df_comparison['model'], df_comparison[metric], color=colors, edgecolor='black', alpha=0.8)
            ax.set_title(f'{metric.upper()}', fontsize=12, fontweight='bold')
            ax.set_ylabel('Valor')
            plt.setp(ax.get_xticklabels(), rotation=45)

            # Value labels just above each bar.
            for bar, val in zip(bars, df_comparison[metric]):
                ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + bar.get_height()*0.02,
                       f'{val:.3f}', ha='center', va='bottom', fontweight='bold', fontsize=9)

        fig.tight_layout()
        comparison_plot_path = os.path.join(RESULTS_PATH, 'plots', f'comparison_{run_stamp}.png')
        fig.savefig(comparison_plot_path, dpi=300, bbox_inches='tight')
        print(f"‚úì Gr√°fico comparativo salvo: {comparison_plot_path}")
        plt.close(fig)

    print("\n" + "="*80)
    print("‚úÖ PIPELINE COMPLETO!")
    print("="*80)
    print(f"üìÅ Resultados salvos em: {RESULTS_PATH}")


if __name__ == "__main__":
    main()





PIPELINE DE AVALIA√á√ÉO DE 3 MODELOS TREINADOS
Device: cuda
Dataset Path: ../data/processed/RAF-DB
Models Path: ../results/augmented_online/EXPW/models/
Results Path: ../results/cross_data/expw_for_raf_processed

CARREGANDO DATASET DE TESTE

üìÇ Carregando dataset PR√â-PROCESSADO: ../data/processed/RAF-DB
  Subdiret√≥rios encontrados: ['test', 'train']
‚úì Estrutura com split detectada (train=True, test=True)
  ‚Üí Usando subset: test
  ‚Üí Caminho base: ../data/processed/RAF-DB/test
  Emo√ß√µes em test: ['Tristeza', 'Raiva', 'Neutro', 'Surpresa', 'Felicidade', 'Medo', 'Nojo']

  Iniciando carregamento de dados PR√â-PROCESSADOS...
  ‚úì Raiva          :  162 imagens carregadas
  ‚úì Nojo           :  160 imagens carregadas
  ‚úì Medo           :   74 imagens carregadas
  ‚úì Felicidade     : 1185 imagens carregadas
  ‚úì Neutro         :  680 imagens carregadas
  ‚úì Tristeza       :  478 imagens carregadas
  ‚úì Surpresa       :  329 imagens carregadas

‚úì Total: 3068 imagens carre