In [34]:
"""
=============================================================================
TREINAMENTO RAF-DB COM AUGMENTATION ONLINE AGRESSIVO (80%)
=============================================================================
Pipeline de treinamento com augmentation online forte:
1. Carrega dataset augmented_raf (já balanceado offline)
2. Aplica augmentation ONLINE agressivo (p=0.8) - 80% das imagens
3. Treina múltiplos modelos: ResNet50, EfficientNet-B0, ViT
4. Early stopping para evitar overfitting
5. Métricas completas para comparação

ESTRATÉGIA:
- Dataset base: augmented_raf (balanceado offline)
- Augmentation online: MUITO AGRESSIVO (p=0.8) - NÃO salva em disco
- Modelos: 3 arquiteturas diferentes
- Métricas: Completas (Acc, F1, Precision, Recall, por classe)
- Early Stopping: Patience=15

DIFERENCIAL:
- 80% das imagens sofrem transformação A CADA ÉPOCA
- Máxima variabilidade sem explodir armazenamento
- Transformações mais fortes que offline

=============================================================================
"""

import os
import cv2
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score, 
    f1_score, 
    precision_score, 
    recall_score,
    classification_report,
    confusion_matrix
)
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import gc

warnings.filterwarnings('ignore')

In [35]:
# =============================================================================
# CONFIGURAÇÕES
# =============================================================================

# Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Paths (augmented RAF-DB, already balanced offline)
DATASET_PATH = r"../data/augmented/RAF-DB"
RESULTS_PATH = r"../results/augmented_online/RAF-DB"
MODELS_PATH = os.path.join(RESULTS_PATH, "models")
METRICS_PATH = os.path.join(RESULTS_PATH, "metrics")
PLOTS_PATH = os.path.join(RESULTS_PATH, "plots")

# Create the output directories (no error if they already exist)
os.makedirs(MODELS_PATH, exist_ok=True)
os.makedirs(METRICS_PATH, exist_ok=True)
os.makedirs(PLOTS_PATH, exist_ok=True)

# Training parameters
EPOCHS = 100
LR = 1e-4
BATCH_SIZE = 32  # Reduced so the batch fits on the GPU with augmentation
NUM_WORKERS = 4

# Early stopping
EARLY_STOP_PATIENCE = 15
EARLY_STOP_MIN_DELTA = 0.001

# Weights of the combined model-selection metric (F1-macro and accuracy)
F1_WEIGHT = 0.6
ACC_WEIGHT = 0.4

# RAF-DB class mapping: directory name -> integer label
EMOTION_LABELS = {
    'Raiva': 0, 
    'Nojo': 1, 
    'Medo': 2, 
    'Felicidade': 3,
    'Neutro': 4, 
    'Tristeza': 5, 
    'Surpresa': 6
}

# GPU memory configuration
if DEVICE == "cuda":
    # Configure the allocator before issuing any CUDA memory call.
    # NOTE(review): if an earlier cell already allocated CUDA memory in this
    # kernel, the setting may not take effect until the kernel is restarted.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    # Collect garbage first so tensors kept alive only by reference cycles are
    # freed, and only then hand the now-unused cached blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()

print("="*80)
print("TREINAMENTO RAF-DB - AUGMENTATION ONLINE AGRESSIVO (80%)")
print("="*80)
print(f"Dispositivo: {DEVICE}")
print(f"Dataset: {DATASET_PATH}")
print(f"Épocas máximas: {EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Augmentation: p=0.8 (80% das imagens)")
print(f"Early Stopping: Patience={EARLY_STOP_PATIENCE}")
print("="*80)

TREINAMENTO RAF-DB - AUGMENTATION ONLINE AGRESSIVO (80%)
Dispositivo: cuda
Dataset: ../data/augmented/RAF-DB
Épocas máximas: 100
Batch size: 32
Augmentation: p=0.8 (80% das imagens)
Early Stopping: Patience=15


In [36]:
# =============================================================================
# EARLY STOPPING
# =============================================================================

class EarlyStopping:
    """
    Signals when a monitored validation score has stopped improving.

    Call the instance once per epoch with the latest score. It returns True
    once ``patience`` calls have failed to beat the best score by at least
    ``min_delta``. The patience counter is reset by every improvement, but
    ``early_stop`` is never cleared once it has been set.

    Args:
        patience: number of non-improving calls tolerated before stopping.
        min_delta: minimum gain over the best score that counts as improvement.
        mode: 'max' if higher scores are better; any other value is treated
            as "lower is better".
        verbose: print a status line on every call.
    """
    def __init__(self, patience=15, min_delta=0.001, mode='max', verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Raw (un-negated) value of the best score seen so far.
        self.val_score_min = -np.inf if mode != 'min' else np.inf

    def _record_best(self, signed_score, raw_score):
        """Store a new best score in both its signed and raw forms."""
        self.best_score = signed_score
        self.val_score_min = raw_score

    def __call__(self, val_score):
        # Work internally with "higher is better" by negating in non-max mode.
        signed = val_score if self.mode == 'max' else -val_score

        # First call: just record the baseline.
        if self.best_score is None:
            self._record_best(signed, val_score)
            if self.verbose:
                print(f'Early Stopping: Baseline: {val_score:.4f}')
            return self.early_stop

        # Improvement of at least min_delta: update the best and reset patience.
        if not (signed < self.best_score + self.min_delta):
            if self.verbose:
                improvement = val_score - self.val_score_min
                print(f'Early Stopping: Melhoria! ({val_score:.4f}, Δ={improvement:+.4f})')
            self._record_best(signed, val_score)
            self.counter = 0
            return self.early_stop

        # No sufficient improvement: consume one unit of patience.
        self.counter += 1
        if self.verbose:
            print(f'Early Stopping: Sem melhoria ({self.counter}/{self.patience})')
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print(f'Early Stopping: ATIVADO após {self.patience} épocas')
        return self.early_stop


In [37]:
# =============================================================================
# AUGMENTATION ONLINE - AGRESSIVO (80%)
# =============================================================================

def create_train_transform():
    """
    Build the aggressive ONLINE augmentation pipeline used for training.

    With probability 0.8 exactly one transform is drawn from the pool below
    and applied; the remaining 20% of samples skip augmentation. Every sample
    is then resized to 224x224, normalised with mean 0.5 / std 0.5 and
    converted to a tensor. Nothing is written to disk: the pipeline is meant
    to be applied each time a sample is loaded.

    Returns:
        A.Compose: the training transform.
    """
    # Geometric transforms
    geometric = [
        A.HorizontalFlip(p=1.0),
        A.Rotate(limit=20, p=1.0),
        A.ShiftScaleRotate(
            shift_limit=0.15,
            scale_limit=0.15,
            rotate_limit=20,
            p=1.0
        ),
        A.Affine(
            scale=(0.85, 1.15),
            translate_percent=(-0.15, 0.15),
            rotate=(-20, 20),
            shear=(-10, 10),
            p=1.0
        ),
        A.Perspective(scale=(0.05, 0.1), p=1.0),
    ]

    # Intensity transforms
    intensity = [
        A.RandomBrightnessContrast(
            brightness_limit=0.3,
            contrast_limit=0.3,
            p=1.0
        ),
        A.RandomGamma(gamma_limit=(80, 120), p=1.0),
        A.RandomToneCurve(scale=0.3, p=1.0),
        A.Equalize(p=1.0),
        A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1.0),
    ]

    # Noise and blur
    noise_and_blur = [
        A.GaussNoise(var_limit=(15.0, 50.0), p=1.0),
        A.GaussianBlur(blur_limit=(3, 7), p=1.0),
        A.MotionBlur(blur_limit=7, p=1.0),
        A.MedianBlur(blur_limit=7, p=1.0),
    ]

    # Degradation (dropout and resolution loss)
    degradation = [
        A.PixelDropout(dropout_prob=0.02, p=1.0),
        A.CoarseDropout(
            max_holes=8,
            max_height=16,
            max_width=16,
            min_holes=1,
            min_height=4,
            min_width=4,
            p=1.0
        ),
        A.Downscale(
            scale_min=0.5,
            scale_max=0.75,
            interpolation=cv2.INTER_LINEAR,
            p=1.0
        ),
    ]

    # Spatial distortions
    distortion = [
        A.ElasticTransform(
            alpha=2,
            sigma=50,
            alpha_affine=30,
            p=1.0
        ),
        A.GridDistortion(
            num_steps=5,
            distort_limit=0.3,
            p=1.0
        ),
        A.OpticalDistortion(
            distort_limit=0.15,
            shift_limit=0.15,
            p=1.0
        ),
    ]

    # 80% chance of applying ONE transform from the combined pool
    # (the pool order is kept: geometric, intensity, noise, degradation, distortion).
    one_strong_transform = A.OneOf(
        [*geometric, *intensity, *noise_and_blur, *degradation, *distortion],
        p=0.8,
    )

    pipeline = [
        one_strong_transform,
        # Resize (safety net so every sample has the same size)
        A.Resize(height=224, width=224),
        # Normalisation
        A.Normalize(mean=[0.5], std=[0.5]),
        ToTensorV2(),
    ]
    return A.Compose(pipeline)


def create_val_transform():
    """
    Build the deterministic validation/test transform (no augmentation).

    Samples are resized to 224x224, normalised with mean 0.5 / std 0.5 and
    converted to a tensor.

    Returns:
        A.Compose: the evaluation transform.
    """
    eval_steps = [
        A.Resize(height=224, width=224),
        A.Normalize(mean=[0.5], std=[0.5]),
        ToTensorV2(),
    ]
    return A.Compose(eval_steps)


In [38]:
# =============================================================================
# DATASET CUSTOMIZADO
# =============================================================================

class EmotionDataset(Dataset):
    """
    RAF-DB dataset read from ``root_dir/<class name>/<image file>``.

    Class names and integer labels come from ``EMOTION_LABELS``. Images are
    read as single-channel grayscale and passed through ``transform`` (an
    albumentations-style callable taking ``image=`` and returning a dict with
    an ``'image'`` entry) each time a sample is requested.

    Args:
        root_dir: directory holding one sub-directory per class.
        transform: optional callable applied to every image on access.

    Raises:
        ValueError: if no image file is found under ``root_dir``.
    """
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Collect paths and labels. Class directories that do not exist are
        # skipped silently.
        for class_name, class_idx in EMOTION_LABELS.items():
            class_path = os.path.join(self.root_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping identical across runs and machines.
            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                    self.image_paths.append(os.path.join(class_path, img_name))
                    self.labels.append(class_idx)

        if len(self.image_paths) == 0:
            raise ValueError(f"Nenhuma imagem em {root_dir}")

    def __len__(self):
        """Number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            tuple: (image, label) where ``label`` is a ``torch.long`` scalar
            tensor and ``image`` is whatever ``transform`` produces (or the raw
            grayscale array when no transform was given).

        Raises:
            ValueError: if OpenCV cannot read the image file.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Read the image as grayscale; cv2.imread returns None on failure.
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

        if image is None:
            raise ValueError(f"Erro ao carregar: {img_path}")

        # Explicit None check: a transform object must not be skipped just
        # because it happens to evaluate as falsy (e.g. an empty pipeline).
        if self.transform is not None:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, torch.tensor(label, dtype=torch.long)


In [39]:
# =============================================================================
# MODELOS
# =============================================================================

def _to_single_channel(conv):
    """
    Return a 1-input-channel copy of a pretrained convolution.

    The new layer keeps the original output channels, kernel size, stride and
    padding. Its weights are the original weights averaged over the input
    channel axis, and the pretrained bias (when the layer has one) is copied
    instead of being left randomly initialised.
    """
    has_bias = conv.bias is not None
    adapted = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=has_bias,
    )
    with torch.no_grad():
        adapted.weight.copy_(conv.weight.mean(dim=1, keepdim=True))
        if has_bias:
            adapted.bias.copy_(conv.bias)
    return adapted


def create_model(model_name, num_classes=7):
    """
    Build an ImageNet-pretrained classifier adapted to grayscale input.

    The first convolution is replaced by a single-channel version initialised
    from the pretrained weights, and the classification head is replaced by a
    fresh ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        model_name: 'resnet50', 'efficientnet_b0' or 'vit_b_16'.
        num_classes: number of output classes.

    Returns:
        The model, or None if building the ViT fails (the error is printed).

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name == 'resnet50':
        model = models.resnet50(weights='IMAGENET1K_V1')

        # Adapt the stem convolution to grayscale
        model.conv1 = _to_single_channel(model.conv1)

        # Classification head
        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, num_classes)

    elif model_name == 'efficientnet_b0':
        model = models.efficientnet_b0(weights='IMAGENET1K_V1')

        # Adapt the stem convolution to grayscale
        model.features[0][0] = _to_single_channel(model.features[0][0])

        # Classification head
        num_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_features, num_classes)

    elif model_name == 'vit_b_16':
        try:
            model = models.vit_b_16(weights='IMAGENET1K_V1')

            # Adapt the patch projection to grayscale. This layer has a bias,
            # which the helper carries over from the pretrained weights.
            model.conv_proj = _to_single_channel(model.conv_proj)

            # Classification head
            model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)

        except Exception as e:
            # Best-effort: callers treat None as "skip this model".
            print(f"❌ Erro ao criar ViT: {e}")
            return None
    else:
        raise ValueError(f"Modelo não suportado: {model_name}")

    print(f"✓ Modelo '{model_name}' criado")
    return model

In [40]:
# =============================================================================
# AVALIAÇÃO FINAL
# =============================================================================

def evaluate_model(model_name, test_loader, device):
    """
    Evaluate the best saved checkpoint of ``model_name`` on the test set.

    NOTE: a later cell of this notebook defines ``evaluate_model`` again; when
    the cells run in order, that later definition is the one in effect.

    Loads ``<MODELS_PATH>/<model_name>_best.pth``, runs ``validate`` on
    ``test_loader``, prints the metrics and saves the confusion-matrix plot to
    ``PLOTS_PATH``.

    Args:
        model_name: architecture name understood by ``create_model``.
        test_loader: loader yielding (images, labels) batches.
        device: device the model is evaluated on.

    Returns:
        dict with the test metrics, the per-class report and the confusion
        matrix (as nested lists), or None when the checkpoint file is missing
        or the model cannot be built.
    """
    print(f"\n{'='*80}")
    print(f"AVALIAÇÃO FINAL: {model_name.upper()}")
    print(f"{'='*80}")

    # Check the checkpoint first so no pretrained model is built (or
    # downloaded) when there is nothing to evaluate.
    model_path = os.path.join(MODELS_PATH, f"{model_name}_best.pth")
    if not os.path.exists(model_path):
        print(f"❌ Modelo não encontrado: {model_path}")
        return None

    # create_model may return None (it does so when the ViT cannot be built).
    model = create_model(model_name)
    if model is None:
        return None

    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Predictions
    test_preds, test_labels = validate(model, test_loader, device)

    # Aggregate metrics
    test_accuracy = accuracy_score(test_labels, test_preds)
    test_f1_macro = f1_score(test_labels, test_preds, average='macro')
    test_f1_weighted = f1_score(test_labels, test_preds, average='weighted')
    test_precision = precision_score(test_labels, test_preds, average='macro', zero_division=0)
    test_recall = recall_score(test_labels, test_preds, average='macro', zero_division=0)

    # Passing the label ids explicitly keeps the report rows and the matrix
    # aligned with class_names even if a class never appears in the labels or
    # predictions (without it, classification_report raises and the matrix
    # shrinks below the number of tick labels).
    class_names = list(EMOTION_LABELS.keys())
    class_ids = list(EMOTION_LABELS.values())

    # Per-class report
    class_report = classification_report(
        test_labels, test_preds,
        labels=class_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    # Confusion matrix
    conf_matrix = confusion_matrix(test_labels, test_preds, labels=class_ids)

    print(f"\n📊 MÉTRICAS DE TESTE:")
    print(f"  Accuracy:         {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
    print(f"  F1-Score (Macro): {test_f1_macro:.4f}")
    print(f"  F1-Score (Wtd):   {test_f1_weighted:.4f}")
    print(f"  Precision:        {test_precision:.4f}")
    print(f"  Recall:           {test_recall:.4f}")

    print(f"\n📋 RELATÓRIO POR CLASSE:")
    print(classification_report(
        test_labels, test_preds,
        labels=class_ids,
        target_names=class_names,
        zero_division=0
    ))

    # Plot the confusion matrix and save it; the figure is closed so repeated
    # evaluations do not accumulate open figures.
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Matriz de Confusão - {model_name.upper()}', fontsize=14, fontweight='bold')
    plt.ylabel('Classe Real')
    plt.xlabel('Classe Predita')
    plt.tight_layout()
    plt.savefig(os.path.join(PLOTS_PATH, f'confusion_matrix_{model_name}.png'), dpi=300)
    plt.close()

    return {
        'model_name': model_name,
        'test_accuracy': test_accuracy,
        'test_f1_macro': test_f1_macro,
        'test_f1_weighted': test_f1_weighted,
        'test_precision': test_precision,
        'test_recall': test_recall,
        'class_report': class_report,
        'confusion_matrix': conf_matrix.tolist()
    }


In [41]:
# =============================================================================
# FUNÇÃO PRINCIPAL
# =============================================================================

def main():
    """
    Run the whole pipeline: load data, train each model, evaluate, save metrics.

    NOTE: a later cell of this notebook defines ``main`` again; when the cells
    run in order, that later definition is the one the entry-point cell calls.

    NOTE(review): the test loader is also passed to ``train_model`` as the
    validation loader, so checkpoint selection and early stopping see the test
    set and the reported test metrics are optimistic.
    """
    # Prepare datasets
    print("\n📂 Carregando datasets...")

    train_transform = create_train_transform()
    val_transform = create_val_transform()

    train_dataset = EmotionDataset(
        root_dir=os.path.join(DATASET_PATH, 'train'),
        transform=train_transform
    )

    test_dataset = EmotionDataset(
        root_dir=os.path.join(DATASET_PATH, 'test'),
        transform=val_transform
    )

    print(f"✓ Dataset de treino: {len(train_dataset)} imagens")
    print(f"✓ Dataset de teste: {len(test_dataset)} imagens")

    # DataLoaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=False
    )

    # Models to train. 'efficientvit_m5' was removed from this list: the
    # create_model in this notebook rejects it with ValueError, which aborted
    # the run. Add it back only once create_model supports it.
    models_to_train = ['efficientnet_b0']

    # Results
    all_results = []

    # Train each model
    for model_name in models_to_train:
        # Fresh early-stopping state per model
        early_stopping = EarlyStopping(
            patience=EARLY_STOP_PATIENCE,
            min_delta=EARLY_STOP_MIN_DELTA,
            mode='max',
            verbose=True
        )

        # Train
        train_result = train_model(
            model_name,
            train_loader,
            test_loader,  # test set doubles as validation set
            DEVICE,
            EPOCHS,
            early_stopping
        )

        if train_result:
            # Evaluate
            eval_result = evaluate_model(model_name, test_loader, DEVICE)

            if eval_result:
                # Merge training and evaluation results
                combined_result = {**train_result, **eval_result}
                all_results.append(combined_result)

                # Save per-model metrics
                metrics_df = pd.DataFrame([combined_result])
                metrics_df.to_csv(
                    os.path.join(METRICS_PATH, f'metrics_{model_name}.csv'),
                    index=False
                )

    # Save the comparison. The file name is fixed: it previously embedded the
    # loop variable, i.e. the name of whichever model happened to be last.
    if all_results:
        comparison_df = pd.DataFrame(all_results)
        comparison_df.to_csv(
            os.path.join(METRICS_PATH, 'comparison_all_models.csv'),
            index=False
        )

        print(f"\n{'='*80}")
        print("COMPARAÇÃO FINAL DE MODELOS")
        print(f"{'='*80}")
        print(comparison_df[['model_name', 'test_accuracy', 'test_f1_macro', 'best_epoch', 'epochs_completed']].to_string(index=False))
        print(f"{'='*80}")

    print(f"\n✅ Treinamento completo finalizado!")
    print(f"📁 Resultados salvos em: {RESULTS_PATH}")

In [42]:
# =============================================================================
# TREINAMENTO
# =============================================================================

def train_epoch(model, train_loader, criterion, optimizer, device, scaler):
    """
    Train the model for one epoch.

    Args:
        model: network to optimise (switched to training mode here).
        train_loader: iterable of (images, labels) batches; must support len().
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimiser stepped once per batch through ``scaler``.
        device: target device, given as a string ('cuda', 'cuda:0', 'cpu') or
            a ``torch.device``.
        scaler: gradient scaler used for the backward pass and optimiser step.

    Returns:
        float: mean of the per-batch losses.
    """
    model.train()
    running_loss = 0.0

    # Decide on mixed precision from the device *type*: comparing the argument
    # with the literal "cuda" is False for torch.device objects and for
    # indexed names such as "cuda:0", which silently disabled autocast.
    use_amp = torch.device(device).type == "cuda"

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass with mixed precision
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Backward pass and optimiser step through the gradient scaler
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    return running_loss / len(train_loader)


def validate(model, val_loader, device):
    """
    Run the model over a loader and collect predictions.

    Args:
        model: network to evaluate (switched to eval mode here and left so).
        val_loader: iterable of (images, labels) batches.
        device: target device, given as a string ('cuda', 'cuda:0', 'cpu') or
            a ``torch.device``.

    Returns:
        tuple: (predictions, labels) as NumPy arrays, in loader order.
        Predictions are the argmax over dimension 1 of the model output.
    """
    model.eval()
    all_preds = []
    all_labels = []

    # Decide on mixed precision from the device *type*: comparing the argument
    # with the literal "cuda" is False for torch.device objects and for
    # indexed names such as "cuda:0", which silently disabled autocast.
    use_amp = torch.device(device).type == "cuda"

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(images)

            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return np.array(all_preds), np.array(all_labels)


def train_model(model_name, train_loader, val_loader, device, epochs, early_stopping):
    """
    Train one architecture with checkpointing and early stopping.

    After each epoch the model is scored on ``val_loader``. The checkpoint
    ``<MODELS_PATH>/<model_name>_best.pth`` is overwritten whenever the
    combined metric (F1_WEIGHT * macro-F1 + ACC_WEIGHT * accuracy) improves.

    NOTE(review): whatever is passed as ``val_loader`` drives model selection
    and early stopping, so it should not be the set used for the final report.

    Args:
        model_name: architecture name understood by ``create_model``.
        train_loader: training batches.
        val_loader: validation batches.
        device: target device, as a string or ``torch.device``.
        epochs: maximum number of epochs.
        early_stopping: callable taking the combined metric and returning True
            when training should stop.

    Returns:
        dict with the per-epoch history, the best scores, the best epoch, the
        wall-clock time and the number of completed epochs; or None when
        ``create_model`` returns None.
    """
    print(f"\n{'='*80}")
    print(f"TREINANDO: {model_name.upper()}")
    print(f"{'='*80}")

    # Build the model
    model = create_model(model_name)
    if model is None:
        return None

    model.to(device)

    # Setup. The scaler is enabled from the device *type*, so torch.device
    # objects and names like "cuda:0" also get loss scaling.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    scaler = torch.cuda.amp.GradScaler(enabled=(torch.device(device).type == "cuda"))

    # Per-epoch history
    history = {
        'train_loss': [],
        'val_accuracy': [],
        'val_f1_macro': [],
        'val_precision': [],
        'val_recall': [],
        'combined_metric': []
    }

    best_combined_metric = 0.0
    best_val_f1 = 0.0
    best_val_accuracy = 0.0
    best_epoch = 0

    start_time = time.time()

    # Training loop
    for epoch in range(epochs):
        epoch_start = time.time()

        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device, scaler)

        # Validate
        val_preds, val_labels = validate(model, val_loader, device)

        # Metrics
        val_accuracy = accuracy_score(val_labels, val_preds)
        val_f1_macro = f1_score(val_labels, val_preds, average='macro')
        val_precision = precision_score(val_labels, val_preds, average='macro', zero_division=0)
        val_recall = recall_score(val_labels, val_preds, average='macro', zero_division=0)

        # Combined metric
        combined_metric = (F1_WEIGHT * val_f1_macro) + (ACC_WEIGHT * val_accuracy)

        # Record history
        history['train_loss'].append(train_loss)
        history['val_accuracy'].append(val_accuracy)
        history['val_f1_macro'].append(val_f1_macro)
        history['val_precision'].append(val_precision)
        history['val_recall'].append(val_recall)
        history['combined_metric'].append(combined_metric)

        epoch_time = time.time() - epoch_start

        # Report
        print(f"\nÉpoca {epoch+1}/{epochs} - {epoch_time:.1f}s")
        print(f"  Train Loss:    {train_loss:.4f}")
        print(f"  Val Accuracy:  {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")
        print(f"  Val F1-Macro:  {val_f1_macro:.4f}")
        print(f"  Val Precision: {val_precision:.4f}")
        print(f"  Val Recall:    {val_recall:.4f}")
        print(f"  Combined:      {combined_metric:.4f}")

        # Save the best model. The first epoch is always saved so a checkpoint
        # exists even if the metric never rises above its initial 0.0.
        if best_epoch == 0 or combined_metric > best_combined_metric:
            best_combined_metric = combined_metric
            best_val_f1 = val_f1_macro
            best_val_accuracy = val_accuracy
            best_epoch = epoch + 1

            model_path = os.path.join(MODELS_PATH, f"{model_name}_best.pth")
            # Scores are stored as plain Python floats: NumPy scalars inside a
            # checkpoint can make it unreadable by torch.load when only
            # weights-safe types are allowed.
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': float(val_f1_macro),
                'val_accuracy': float(val_accuracy),
                'combined_metric': float(combined_metric),
            }, model_path)

            print(f"  ✓ Modelo salvo! (Combined: {combined_metric:.4f})")

        # Early stopping
        if early_stopping(combined_metric):
            print(f"\n{'='*80}")
            print(f"EARLY STOPPING - ÉPOCA {epoch+1}")
            print(f"Melhor época: {best_epoch}")
            print(f"Melhor Combined: {best_combined_metric:.4f}")
            print(f"Melhor F1: {best_val_f1:.4f} | Melhor Acc: {best_val_accuracy:.4f}")
            print(f"{'='*80}")
            break

    total_time = time.time() - start_time

    print(f"\n{'='*80}")
    print(f"TREINAMENTO CONCLUÍDO: {model_name.upper()}")
    print(f"{'='*80}")
    print(f"Tempo total: {str(timedelta(seconds=int(total_time)))}")
    print(f"Épocas: {len(history['train_loss'])}/{epochs}")
    print(f"Melhor Combined: {best_combined_metric:.4f} (época {best_epoch})")
    print(f"Melhor F1: {best_val_f1:.4f}")
    print(f"Melhor Accuracy: {best_val_accuracy:.4f}")
    print(f"{'='*80}")

    return {
        'model_name': model_name,
        'history': history,
        'best_combined_metric': best_combined_metric,
        'best_val_f1': best_val_f1,
        'best_val_accuracy': best_val_accuracy,
        'best_epoch': best_epoch,
        'total_time': total_time,
        'epochs_completed': len(history['train_loss'])
    }



In [43]:
# =============================================================================
# AVALIAÇÃO FINAL
# =============================================================================

def evaluate_model(model_name, test_loader, device):
    """
    Evaluate the best saved checkpoint of ``model_name`` on the test set.

    Loads ``<MODELS_PATH>/<model_name>_best.pth``, runs ``validate`` on
    ``test_loader``, prints the metrics and saves the confusion-matrix plot to
    ``PLOTS_PATH``.

    Args:
        model_name: architecture name understood by ``create_model``.
        test_loader: loader yielding (images, labels) batches.
        device: device the model is evaluated on.

    Returns:
        dict with the test metrics, the per-class report and the confusion
        matrix (as nested lists), or None when the checkpoint file is missing
        or the model cannot be built.
    """
    print(f"\n{'='*80}")
    print(f"AVALIAÇÃO FINAL: {model_name.upper()}")
    print(f"{'='*80}")

    # Check the checkpoint first so no pretrained model is built (or
    # downloaded) when there is nothing to evaluate.
    model_path = os.path.join(MODELS_PATH, f"{model_name}_best.pth")
    if not os.path.exists(model_path):
        print(f"❌ Modelo não encontrado: {model_path}")
        return None

    # create_model may return None (it does so when the ViT cannot be built).
    model = create_model(model_name)
    if model is None:
        return None

    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Predictions
    test_preds, test_labels = validate(model, test_loader, device)

    # Aggregate metrics
    test_accuracy = accuracy_score(test_labels, test_preds)
    test_f1_macro = f1_score(test_labels, test_preds, average='macro')
    test_f1_weighted = f1_score(test_labels, test_preds, average='weighted')
    test_precision = precision_score(test_labels, test_preds, average='macro', zero_division=0)
    test_recall = recall_score(test_labels, test_preds, average='macro', zero_division=0)

    # Passing the label ids explicitly keeps the report rows and the matrix
    # aligned with class_names even if a class never appears in the labels or
    # predictions (without it, classification_report raises and the matrix
    # shrinks below the number of tick labels).
    class_names = list(EMOTION_LABELS.keys())
    class_ids = list(EMOTION_LABELS.values())

    # Per-class report
    class_report = classification_report(
        test_labels, test_preds,
        labels=class_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    # Confusion matrix
    conf_matrix = confusion_matrix(test_labels, test_preds, labels=class_ids)

    print(f"\n📊 MÉTRICAS DE TESTE:")
    print(f"  Accuracy:         {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
    print(f"  F1-Score (Macro): {test_f1_macro:.4f}")
    print(f"  F1-Score (Wtd):   {test_f1_weighted:.4f}")
    print(f"  Precision:        {test_precision:.4f}")
    print(f"  Recall:           {test_recall:.4f}")

    print(f"\n📋 RELATÓRIO POR CLASSE:")
    print(classification_report(
        test_labels, test_preds,
        labels=class_ids,
        target_names=class_names,
        zero_division=0
    ))

    # Plot the confusion matrix and save it; the figure is closed so repeated
    # evaluations do not accumulate open figures.
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Matriz de Confusão - {model_name.upper()}', fontsize=14, fontweight='bold')
    plt.ylabel('Classe Real')
    plt.xlabel('Classe Predita')
    plt.tight_layout()
    plt.savefig(os.path.join(PLOTS_PATH, f'confusion_matrix_{model_name}.png'), dpi=300)
    plt.close()

    return {
        'model_name': model_name,
        'test_accuracy': test_accuracy,
        'test_f1_macro': test_f1_macro,
        'test_f1_weighted': test_f1_weighted,
        'test_precision': test_precision,
        'test_recall': test_recall,
        'class_report': class_report,
        'confusion_matrix': conf_matrix.tolist()
    }

In [44]:
# =============================================================================
# MAIN
# =============================================================================

def main():
    """
    Run the whole pipeline: load data, train each model, evaluate, save metrics.

    For every architecture in ``models_to_train`` a fresh ``EarlyStopping`` is
    created, the model is trained and evaluated, and the merged results are
    written to ``metrics_<model>.csv``. A final ``comparison_all_models.csv``
    gathers all models that produced results.

    NOTE(review): the test loader is also passed to ``train_model`` as the
    validation loader, so checkpoint selection and early stopping see the test
    set and the reported test metrics are optimistic.
    """
    print("\n📂 Carregando datasets...")

    train_transform = create_train_transform()  # online augmentation, p=0.8
    val_transform = create_val_transform()      # no augmentation

    train_dataset = EmotionDataset(
        root_dir=os.path.join(DATASET_PATH, 'train'),
        transform=train_transform
    )

    test_dataset = EmotionDataset(
        root_dir=os.path.join(DATASET_PATH, 'test'),
        transform=val_transform
    )

    print(f"✓ Dataset de treino: {len(train_dataset):,} imagens")
    print(f"✓ Dataset de teste: {len(test_dataset):,} imagens")
    print(f"⚡ Augmentation online: 80% das imagens transformadas A CADA ÉPOCA")

    # DataLoaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    # Models
    models_to_train = ['resnet50', 'efficientnet_b0', 'vit_b_16']

    all_results = []

    # Train each model
    for model_name in models_to_train:
        # Fresh early-stopping state per model
        early_stopping = EarlyStopping(
            patience=EARLY_STOP_PATIENCE,
            min_delta=EARLY_STOP_MIN_DELTA,
            mode='max',
            verbose=True
        )

        # Train
        train_result = train_model(
            model_name,
            train_loader,
            test_loader,
            DEVICE,
            EPOCHS,
            early_stopping
        )

        if train_result:
            # Evaluate
            eval_result = evaluate_model(model_name, test_loader, DEVICE)

            if eval_result:
                # Merge training and evaluation results
                combined_result = {**train_result, **eval_result}
                all_results.append(combined_result)

                # Save per-model metrics
                metrics_df = pd.DataFrame([combined_result])
                metrics_df.to_csv(
                    os.path.join(METRICS_PATH, f'metrics_{model_name}.csv'),
                    index=False
                )

        # Free memory between models. Garbage is collected first so tensors
        # kept alive only by reference cycles are released, and only then are
        # the unused cached blocks returned to the driver.
        if DEVICE == "cuda":
            gc.collect()
            torch.cuda.empty_cache()

    # Comparison across models
    if all_results:
        comparison_df = pd.DataFrame(all_results)
        comparison_df.to_csv(
            os.path.join(METRICS_PATH, 'comparison_all_models.csv'),
            index=False
        )

        print(f"\n{'='*80}")
        print("COMPARAÇÃO FINAL - AUGMENTATION ONLINE 80%")
        print(f"{'='*80}")
        print(comparison_df[['model_name', 'test_accuracy', 'test_f1_macro', 
                             'best_epoch', 'epochs_completed']].to_string(index=False))
        print(f"{'='*80}")

    print(f"\n✅ Treinamento completo!")
    print(f"📁 Resultados: {RESULTS_PATH}")

In [45]:
# =============================================================================
# ENTRADA
# =============================================================================

if __name__ == "__main__":
    # Run the pipeline. A manual interrupt is reported briefly; any other
    # error is reported with its type, message and full traceback instead of
    # propagating out of the cell.
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⚠️ Interrompido pelo usuário")
    except Exception as exc:
        import traceback

        error_kind = type(exc).__name__
        print(f"\n\n❌ ERRO: {error_kind}")
        print(f"Mensagem: {exc}")

        print("\nTraceback:")
        traceback.print_exc()


📂 Carregando datasets...
✓ Dataset de treino: 33,404 imagens
✓ Dataset de teste: 3,068 imagens
⚡ Augmentation online: 80% das imagens transformadas A CADA ÉPOCA

TREINANDO: RESNET50


✓ Modelo 'resnet50' criado

Época 1/100 - 90.5s
  Train Loss:    0.9007
  Val Accuracy:  0.7588 (75.88%)
  Val F1-Macro:  0.6571
  Val Precision: 0.6557
  Val Recall:    0.6739
  Combined:      0.6978
  ✓ Modelo salvo! (Combined: 0.6978)
Early Stopping: Baseline: 0.6978

Época 2/100 - 90.4s
  Train Loss:    0.4429
  Val Accuracy:  0.7542 (75.42%)
  Val F1-Macro:  0.6636
  Val Precision: 0.6500
  Val Recall:    0.6832
  Combined:      0.6998
  ✓ Modelo salvo! (Combined: 0.6998)
Early Stopping: Melhoria! (0.6998, Δ=+0.0021)

Época 3/100 - 90.3s
  Train Loss:    0.3211
  Val Accuracy:  0.7744 (77.44%)
  Val F1-Macro:  0.6845
  Val Precision: 0.6845
  Val Recall:    0.6886
  Combined:      0.7205
  ✓ Modelo salvo! (Combined: 0.7205)
Early Stopping: Melhoria! (0.7205, Δ=+0.0206)

Época 4/100 - 90.4s
  Train Loss:    0.2563
  Val Accuracy:  0.7715 (77.15%)
  Val F1-Macro:  0.6953
  Val Precision: 0.7038
  Val Recall:    0.6948
  Combined:      0.7258
  ✓ Modelo salvo! (Combined: 0.7258)
Earl

100%|██████████| 330M/330M [01:02<00:00, 5.54MB/s] 


✓ Modelo 'vit_b_16' criado

Época 1/100 - 213.8s
  Train Loss:    1.5867
  Val Accuracy:  0.6744 (67.44%)
  Val F1-Macro:  0.5644
  Val Precision: 0.5987
  Val Recall:    0.6148
  Combined:      0.6084
  ✓ Modelo salvo! (Combined: 0.6084)
Early Stopping: Baseline: 0.6084

Época 2/100 - 214.3s
  Train Loss:    0.7242
  Val Accuracy:  0.7200 (72.00%)
  Val F1-Macro:  0.6203
  Val Precision: 0.6076
  Val Recall:    0.6685
  Combined:      0.6602
  ✓ Modelo salvo! (Combined: 0.6602)
Early Stopping: Melhoria! (0.6602, Δ=+0.0518)

Época 3/100 - 214.2s
  Train Loss:    0.4589
  Val Accuracy:  0.7510 (75.10%)
  Val F1-Macro:  0.6713
  Val Precision: 0.6899
  Val Recall:    0.6866
  Combined:      0.7032
  ✓ Modelo salvo! (Combined: 0.7032)
Early Stopping: Melhoria! (0.7032, Δ=+0.0430)

Época 4/100 - 214.1s
  Train Loss:    0.3462
  Val Accuracy:  0.7480 (74.80%)
  Val F1-Macro:  0.6605
  Val Precision: 0.6785
  Val Recall:    0.6638
  Combined:      0.6955
Early Stopping: Sem melhoria (1/15)

