# Sprint 2 - Binary Classification of Coffee Leaf Diseases

This notebook implements a binary classification model that detects whether a coffee leaf is:
- **Healthy**: Class 0
- **Not Healthy (Diseased)**: Class 1 (all diseases combined)

## Model Characteristics:
- **Binary Classification**: Simplifies the problem to increase accuracy
- **Advanced Data Augmentation**: Geometric and color transforms, random erasing, Mixup, and CutMix
- **Custom Architectures**: CNNs optimized for binary classification
- **Learning Rate Scheduling**: Dynamic adjustment of the learning rate
- **Early Stopping**: Prevention of overfitting
- **Detailed Metrics**: Precision, Recall, and F1-Score for each class
- **Step Visualization**: Detailed progress during training
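
As a rough illustration of the per-class metrics listed above, the sketch below derives precision, recall, and F1 from raw counts, treating one class at a time as the positive class. The helper `binary_metrics` and the label lists are hypothetical and exist only for this example; the notebook itself relies on scikit-learn for these values.

```python
def binary_metrics(targets, preds, positive=1):
    """Accuracy, precision, recall and F1 with `positive` treated as the positive class."""
    pairs = list(zip(targets, preds))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    correct = sum(1 for t, p in pairs if t == p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": correct / len(pairs), "precision": precision,
            "recall": recall, "f1": f1}

# Hypothetical labels: 0 = healthy, 1 = not healthy
targets = [0, 0, 0, 1, 1, 1, 1, 1]
preds = [0, 0, 1, 1, 1, 1, 0, 0]

not_healthy = binary_metrics(targets, preds, positive=1)
healthy = binary_metrics(targets, preds, positive=0)
print("Not Healthy:", not_healthy)
print("Healthy:    ", healthy)
```

Accuracy is the same for both calls, while precision and recall differ, which is why reporting them per class is more informative than accuracy alone.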

## Original Classes and Their Binary Labels:
- **Healthy** → Healthy (0)
- **Cerscospora** → Not Healthy (1)
- **Leaf rust** → Not Healthy (1)
- **Miner** → Not Healthy (1)
- **Phoma** → Not Healthy (1)
## 1. Downloading the Datasets

In [None]:
# Download the datasets from Kaggle
import kagglehub

print("📥 Downloading coffee leaf disease datasets...")

# Download each dataset (kagglehub returns the local path)
noamaanabdulazeem_jmuben_coffee_dataset_path = kagglehub.dataset_download('noamaanabdulazeem/jmuben-coffee-dataset')
gauravduttakiit_coffee_leaf_diseases_path = kagglehub.dataset_download('gauravduttakiit/coffee-leaf-diseases')
biniyamyoseph_ethiopian_coffee_leaf_disease_path = kagglehub.dataset_download('biniyamyoseph/ethiopian-coffee-leaf-disease')
mohammedzwaughfa_coffee_leaf_disease_dataset_path = kagglehub.dataset_download('mohammedzwaughfa/coffee-leaf-disease-dataset')

print('✅ Dataset download completed!')
print(f"Dataset 1: {noamaanabdulazeem_jmuben_coffee_dataset_path}")
print(f"Dataset 2: {gauravduttakiit_coffee_leaf_diseases_path}")
print(f"Dataset 3: {biniyamyoseph_ethiopian_coffee_leaf_disease_path}")
print(f"Dataset 4: {mohammedzwaughfa_coffee_leaf_disease_dataset_path}")

## 2. Library Imports

In [None]:
# Import necessary libraries
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR
import torchvision
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, ConcatDataset, random_split, Dataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import pandas as pd
from pathlib import Path
import time
import json
from collections import defaultdict, Counter
import warnings
from tqdm import tqdm
import random
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, roc_curve, auc
warnings.filterwarnings('ignore')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## 3. Model Configuration

In [None]:
# Configuration for Binary Classification
CONFIG = {
    'img_size': 224,
    'batch_size': 32,
    'num_epochs': 50,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'patience': 10,  # For early stopping
    'min_delta': 0.001,  # Minimum change to qualify as improvement
    'subset_fraction': 0.8,  # Use 80% of data
    
    # Data paths
    'save_dir': './modulo_final/sprint_2/modelos',  # For saving models
    
    # Class information - BINARY CLASSIFICATION
    'class_names': ['Healthy', 'Not Healthy'],
    'num_classes': 2,
    
    # Model configurations
    'models_to_train': ['BinaryCNN_Light', 'BinaryCNN_Deep', 'BinaryCNN_Efficient'],
    
    # Training parameters
    'use_mixed_precision': True,
    'gradient_clip_norm': 1.0,
    'scheduler_type': 'cosine',
    
    # Data augmentation
    'use_advanced_augmentation': True,
    'mixup_alpha': 0.2,
    'cutmix_alpha': 1.0,
    'label_smoothing': 0.1
}

# Create save directory
os.makedirs(CONFIG['save_dir'], exist_ok=True)

print("Configuration loaded successfully!")
print(f"Number of classes: {CONFIG['num_classes']}")
print(f"Classes: {CONFIG['class_names']}")
print(f"Models to train: {CONFIG['models_to_train']}")
print(f"Save directory: {CONFIG['save_dir']}")
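
With `scheduler_type` set to `'cosine'`, the learning rate decays along a half cosine over `num_epochs`. The sketch below evaluates the closed-form cosine annealing expression for the values in `CONFIG` (`learning_rate=0.001`, `num_epochs=50`). The function `cosine_lr` is a hypothetical helper, and a minimum learning rate of 0 is assumed.

```python
import math

def cosine_lr(epoch, base_lr=0.001, t_max=50, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

# Sample the schedule every 10 epochs
schedule = [cosine_lr(epoch) for epoch in range(0, 51, 10)]
for epoch, lr in zip(range(0, 51, 10), schedule):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

The rate falls slowly at first, fastest around the midpoint (where it is half the initial value), and flattens again near the end.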

## 4. Custom Dataset for Binary Classification

In [None]:
class BinaryDataset(Dataset):
    """Dataset wrapper that converts multi-class to binary classification"""
    def __init__(self, original_dataset, healthy_class_idx=1):
        """
        Args:
            original_dataset: The original multi-class dataset
            healthy_class_idx: Index of the 'Healthy' class in original dataset
        """
        self.dataset = original_dataset
        self.healthy_class_idx = healthy_class_idx
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # Convert to binary: 0 = Healthy, 1 = Not Healthy
        binary_label = 0 if label == self.healthy_class_idx else 1
        return image, binary_label

print("Binary dataset class created successfully!")
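
torchvision's `ImageFolder` assigns class indices by sorting the class folder names, so the index of the healthy class depends on which folders a dataset contains. The sketch below mimics that sorting in plain Python to show why `healthy_class_idx` has to match the dataset being wrapped. The helpers `healthy_index` and `to_binary` and the two-folder layout are hypothetical.

```python
def healthy_index(class_folders):
    """Index a sorted, ImageFolder-style mapping gives the 'healthy' folder (None if absent)."""
    for idx, name in enumerate(sorted(class_folders)):
        if "healthy" in name.lower():
            return idx
    return None

def to_binary(label, healthy_idx):
    """0 = Healthy, 1 = Not Healthy."""
    return 0 if label == healthy_idx else 1

five_class = ["Healthy", "Cerscospora", "Leaf rust", "Miner", "Phoma"]
two_class = ["healthy", "rust"]  # hypothetical folder layout

idx_five = healthy_index(five_class)
idx_two = healthy_index(two_class)
print(f"Five-class layout: healthy index = {idx_five}")
print(f"Two-class layout:  healthy index = {idx_two}")

# The same original label maps to different binary labels in the two layouts
print(to_binary(1, idx_five), to_binary(1, idx_two))
```

When several datasets with different folder sets are concatenated, a single shared index can therefore mislabel some of them.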

## 5. Data Augmentation

In [None]:
# Advanced Data Augmentation
class AdvancedTransforms:
    """Advanced data augmentation techniques for better model generalization"""
    
    @staticmethod
    def get_train_transforms(img_size=224):
        """Get training transforms with advanced augmentation"""
        return transforms.Compose([
            transforms.Resize((img_size + 32, img_size + 32)),
            transforms.RandomCrop(img_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.2),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.2, scale=(0.02, 0.33), ratio=(0.3, 3.3))
        ])
    
    @staticmethod
    def get_val_transforms(img_size=224):
        """Get validation transforms (resize and normalize only, no augmentation)"""
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

# Mixup and CutMix implementations
def mixup_data(x, y, alpha=1.0):
    """Apply mixup augmentation"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss function"""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def cutmix_data(x, y, alpha=1.0):
    """Apply CutMix augmentation"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    
    y_a, y_b = y, y[index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
    
    # Adjust lambda
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam

def rand_bbox(size, lam):
    """Generate random bounding box for CutMix"""
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    
    return bbx1, bby1, bbx2, bby2

print("Advanced data augmentation functions loaded successfully!")
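
To make the CutMix arithmetic concrete, the sketch below recomputes the bounding box and the adjusted mixing coefficient in plain Python with a fixed box center instead of a random one. The helper names `cutmix_box` and `adjusted_lambda` are hypothetical; the formulas follow `rand_bbox` and the lambda adjustment in `cutmix_data`.

```python
import math

def cutmix_box(width, height, lam, cx, cy):
    """Box covering a (1 - lam) fraction of the image, clipped to the image borders."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_w = int(width * cut_ratio)
    cut_h = int(height * cut_ratio)
    x1 = min(max(cx - cut_w // 2, 0), width)
    y1 = min(max(cy - cut_h // 2, 0), height)
    x2 = min(max(cx + cut_w // 2, 0), width)
    y2 = min(max(cy + cut_h // 2, 0), height)
    return x1, y1, x2, y2

def adjusted_lambda(box, width, height):
    """Fraction of the image that still comes from the original sample."""
    x1, y1, x2, y2 = box
    return 1.0 - (x2 - x1) * (y2 - y1) / (width * height)

centered = cutmix_box(224, 224, lam=0.75, cx=112, cy=112)
corner = cutmix_box(224, 224, lam=0.75, cx=0, cy=0)
print(centered, adjusted_lambda(centered, 224, 224))
print(corner, adjusted_lambda(corner, 224, 224))
```

A box centered in the image keeps the sampled coefficient, while a box clipped at a corner replaces fewer pixels, so the coefficient is recomputed from the actual box area before it weights the two losses.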

## 6. CNN Architectures for Binary Classification

In [None]:
# Binary CNN Architectures
class BinaryCNN_Light(nn.Module):
    """Lightweight Binary CNN - Fast and efficient"""
    def __init__(self):
        super(BinaryCNN_Light, self).__init__()
        
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            # Block 4
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, 2)  # Binary classification
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

class BinaryCNN_Deep(nn.Module):
    """Deeper Binary CNN with residual connections"""
    def __init__(self):
        super(BinaryCNN_Deep, self).__init__()
        
        # Use ResNet-like structure
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(64, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, 2)  # Binary classification
        )
    
    def _make_layer(self, inplanes, planes, blocks, stride=1):
        layers = []
        layers.append(BasicBlock(inplanes, planes, stride))
        for _ in range(1, blocks):
            layers.append(BasicBlock(planes, planes))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

class BasicBlock(nn.Module):
    """Basic residual block"""
    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or inplanes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class BinaryCNN_Efficient(nn.Module):
    """Efficient Binary CNN - Balanced speed and accuracy"""
    def __init__(self):
        super(BinaryCNN_Efficient, self).__init__()
        
        # Depthwise separable convolutions for efficiency
        self.features = nn.Sequential(
            # Stem
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            # Depthwise separable blocks
            self._depthwise_separable(32, 64, stride=1),
            self._depthwise_separable(64, 128, stride=2),
            self._depthwise_separable(128, 256, stride=2),
            self._depthwise_separable(256, 512, stride=2),
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(512, 2)  # Binary classification
        )
    
    def _depthwise_separable(self, in_channels, out_channels, stride=1):
        return nn.Sequential(
            # Depthwise
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, 
                     padding=1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            # Pointwise
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

print("Binary CNN architectures loaded successfully!")
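
The efficiency claim behind `BinaryCNN_Efficient` can be checked with simple arithmetic. The sketch below compares the weight count of a standard 3x3 convolution with that of a depthwise separable one for the four channel stages used in the model, ignoring BatchNorm parameters and assuming bias-free convolutions. The helper functions are hypothetical.

```python
def standard_conv_params(in_ch, out_ch, k=3):
    """Weights of a bias-free k x k convolution."""
    return k * k * in_ch * out_ch

def depthwise_separable_params(in_ch, out_ch, k=3):
    """Weights of a k x k depthwise convolution followed by a 1 x 1 pointwise convolution."""
    depthwise = k * k * in_ch
    pointwise = in_ch * out_ch
    return depthwise + pointwise

stages = [(32, 64), (64, 128), (128, 256), (256, 512)]
total_standard = sum(standard_conv_params(i, o) for i, o in stages)
total_separable = sum(depthwise_separable_params(i, o) for i, o in stages)

for in_ch, out_ch in stages:
    print(f"{in_ch:3d} -> {out_ch:3d}: standard {standard_conv_params(in_ch, out_ch):>9,} "
          f"| separable {depthwise_separable_params(in_ch, out_ch):>7,}")
print(f"Total: standard {total_standard:,} | separable {total_separable:,}")
```

For these stages the separable variant needs roughly one ninth of the convolution weights, which is where most of the savings come from.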

## 7. Loading and Preparing the Data

In [None]:
# Dataset paths
DATASET_PATHS = [
    os.path.join(mohammedzwaughfa_coffee_leaf_disease_dataset_path, "dataset/test"),
    os.path.join(gauravduttakiit_coffee_leaf_diseases_path, "train"),
    os.path.join(gauravduttakiit_coffee_leaf_diseases_path, "test"),
    os.path.join(biniyamyoseph_ethiopian_coffee_leaf_disease_path, "ethiopian cofee leaf dataset/train aug"),
    os.path.join(biniyamyoseph_ethiopian_coffee_leaf_disease_path, "ethiopian cofee leaf dataset/test"),
    os.path.join(noamaanabdulazeem_jmuben_coffee_dataset_path, "JMuBEN"),
]

print("🔍 Loading and preparing datasets for binary classification...")

# Load datasets with transforms
train_transform = AdvancedTransforms.get_train_transforms(CONFIG['img_size'])
val_transform = AdvancedTransforms.get_val_transforms(CONFIG['img_size'])

def load_datasets(paths, transform):
    """Load every available dataset, map its labels to binary, and combine them.

    ImageFolder assigns class indices from each folder's own sorted class names,
    so the 'Healthy' index is resolved per dataset instead of once globally.
    """
    datasets_list = []
    for path in paths:
        if not os.path.exists(path):
            print(f"  Skipped (path not found): {path}")
            continue
        ds = datasets.ImageFolder(path, transform=transform)
        # -1 matches no label, so a dataset without a 'Healthy' folder maps entirely to class 1
        healthy_idx = next((i for i, name in enumerate(ds.classes) if 'healthy' in name.lower()), -1)
        if healthy_idx < 0:
            print(f"  Warning: no 'Healthy' class in {path}; every sample is labelled Not Healthy")
        datasets_list.append(BinaryDataset(ds, healthy_class_idx=healthy_idx))
        print(f"  Loaded: {os.path.basename(path)} - {len(ds)} samples, classes: {ds.classes}")
    return ConcatDataset(datasets_list) if datasets_list else None

# Build two views of the same files: augmented for training, deterministic for validation
combined_dataset = load_datasets(DATASET_PATHS, train_transform)
val_view = load_datasets(DATASET_PATHS, val_transform)

if combined_dataset is None:
    raise ValueError("No datasets were loaded. Please check the dataset paths.")

# Use a random subset for faster training, then split it 70/30 by index
indices = torch.randperm(len(combined_dataset)).tolist()
subset_size = int(len(combined_dataset) * CONFIG['subset_fraction'])
train_size = int(0.7 * subset_size)
val_size = subset_size - train_size

# Both views list the files in the same order, so the indices refer to the same images
binary_dataset = torch.utils.data.Subset(combined_dataset, indices[:subset_size])
train_data = torch.utils.data.Subset(combined_dataset, indices[:train_size])
val_data = torch.utils.data.Subset(val_view, indices[train_size:subset_size])

# Create data loaders
train_loader = DataLoader(train_data, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=4, pin_memory=True)

print(f"\n✅ Dataset preparation completed!")
print(f"Total samples: {len(binary_dataset)}")
print(f"Training samples: {train_size}")
print(f"Validation samples: {val_size}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"\nBinary classes: {CONFIG['class_names']}")
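
The sample and batch counts printed above follow directly from `subset_fraction`, the 70/30 split, and `batch_size`. The sketch below reproduces that arithmetic for a hypothetical total of 10,000 images. The function `split_sizes` is illustrative only, and the batch counts assume the `DataLoader` default of keeping the last partial batch.

```python
import math

def split_sizes(total, subset_fraction=0.8, train_fraction=0.7, batch_size=32):
    """Sample and batch counts for the subset, training split and validation split."""
    subset = int(total * subset_fraction)
    train = int(train_fraction * subset)
    val = subset - train
    return {
        "subset": subset,
        "train": train,
        "val": val,
        "train_batches": math.ceil(train / batch_size),
        "val_batches": math.ceil(val / batch_size),
    }

sizes = split_sizes(10_000)  # hypothetical dataset size
print(sizes)
```

Note that the 20% of images outside the subset are never used, so they could serve as a held-out test set.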

## 8. Training and Evaluation Functions

In [None]:
class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
        
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            
        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False
    
    def save_checkpoint(self, model):
        self.best_weights = {k: v.cpu().clone() for k, v in model.state_dict().items()}

class MetricsTracker:
    """Track training metrics"""
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.train_precisions = []
        self.val_precisions = []
        self.train_recalls = []
        self.val_recalls = []
        self.train_f1s = []
        self.val_f1s = []
        self.learning_rates = []
        self.epochs = []
    
    def update(self, epoch, train_loss, val_loss, train_metrics, val_metrics, lr):
        self.epochs.append(epoch)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.train_accuracies.append(train_metrics['accuracy'])
        self.val_accuracies.append(val_metrics['accuracy'])
        self.train_precisions.append(train_metrics['precision'])
        self.val_precisions.append(val_metrics['precision'])
        self.train_recalls.append(train_metrics['recall'])
        self.val_recalls.append(val_metrics['recall'])
        self.train_f1s.append(train_metrics['f1'])
        self.val_f1s.append(val_metrics['f1'])
        self.learning_rates.append(lr)
    
    def get_best_epoch(self):
        best_idx = np.argmax(self.val_accuracies)
        return self.epochs[best_idx], self.val_accuracies[best_idx]

def calculate_metrics(all_preds, all_targets):
    """Calculate binary classification metrics (precision, recall and F1 refer to class 1, Not Healthy)"""
    accuracy = accuracy_score(all_targets, all_preds)
    precision = precision_score(all_targets, all_preds, average='binary', zero_division=0)
    recall = recall_score(all_targets, all_preds, average='binary', zero_division=0)
    f1 = f1_score(all_targets, all_preds, average='binary', zero_division=0)
    
    return {
        'accuracy': accuracy * 100,
        'precision': precision * 100,
        'recall': recall * 100,
        'f1': f1 * 100
    }

def train_and_evaluate_model(model, model_name, train_loader, val_loader, config):
    """Train and evaluate binary classification model with detailed progress"""
    print(f"\n{'='*80}")
    print(f"🔧 Training: {model_name}")
    print(f"{'='*80}")
    
    metrics = MetricsTracker()
    early_stopping = EarlyStopping(patience=config['patience'], min_delta=config['min_delta'])
    
    criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])
    optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    
    if config['scheduler_type'] == 'cosine':
        scheduler = CosineAnnealingLR(optimizer, T_max=config['num_epochs'])
    elif config['scheduler_type'] == 'plateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
    else:
        scheduler = StepLR(optimizer, step_size=15, gamma=0.1)
    
    scaler = torch.cuda.amp.GradScaler() if config['use_mixed_precision'] and torch.cuda.is_available() else None
    
    model.to(device)
    
    print(f"\n📊 Model Configuration:")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Total parameters: {total_params:,}")
    print(f"  Epochs: {config['num_epochs']}")
    print(f"  Learning rate: {config['learning_rate']}")
    print(f"  Batch size: {config['batch_size']}")
    print(f"  Scheduler: {config['scheduler_type']}")
    print(f"  Mixed precision: {config['use_mixed_precision']}")
    print(f"\n🚀 Starting training...\n")
    
    start_time = time.time()
    
    for epoch in range(config['num_epochs']):
        # Training phase
        model.train()
        train_loss = 0.0
        all_train_preds = []
        all_train_targets = []
        
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]")
        for batch_idx, (data, target) in enumerate(train_pbar):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            
            # Apply augmentation
            if config['use_advanced_augmentation'] and np.random.random() < 0.5:
                if np.random.random() < 0.5:
                    data, target_a, target_b, lam = mixup_data(data, target, config['mixup_alpha'])
                    if scaler:
                        with torch.cuda.amp.autocast():
                            output = model(data)
                            loss = mixup_criterion(criterion, output, target_a, target_b, lam)
                    else:
                        output = model(data)
                        loss = mixup_criterion(criterion, output, target_a, target_b, lam)
                else:
                    data, target_a, target_b, lam = cutmix_data(data, target, config['cutmix_alpha'])
                    if scaler:
                        with torch.cuda.amp.autocast():
                            output = model(data)
                            loss = mixup_criterion(criterion, output, target_a, target_b, lam)
                    else:
                        output = model(data)
                        loss = mixup_criterion(criterion, output, target_a, target_b, lam)
            else:
                if scaler:
                    with torch.cuda.amp.autocast():
                        output = model(data)
                        loss = criterion(output, target)
                else:
                    output = model(data)
                    loss = criterion(output, target)
            
            if scaler:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip_norm'])
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip_norm'])
                optimizer.step()
            
            train_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            all_train_preds.extend(predicted.cpu().numpy())
            all_train_targets.extend(target.cpu().numpy())
            
            # Update progress bar
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        all_val_preds = []
        all_val_targets = []
        
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Val]  ")
        with torch.no_grad():
            for data, target in val_pbar:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                
                val_loss += loss.item()
                _, predicted = torch.max(output.data, 1)
                all_val_preds.extend(predicted.cpu().numpy())
                all_val_targets.extend(target.cpu().numpy())
                
                val_pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        # Calculate metrics
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        
        train_metrics = calculate_metrics(all_train_preds, all_train_targets)
        val_metrics = calculate_metrics(all_val_preds, all_val_targets)
        
        current_lr = optimizer.param_groups[0]['lr']
        
        metrics.update(epoch, train_loss, val_loss, train_metrics, val_metrics, current_lr)
        
        # Print epoch summary
        print(f"\n{'─'*80}")
        print(f"Epoch {epoch+1}/{config['num_epochs']} Summary:")
        print(f"{'─'*80}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train Acc: {train_metrics['accuracy']:.2f}% | Val Acc: {val_metrics['accuracy']:.2f}%")
        print(f"Train Prec: {train_metrics['precision']:.2f}% | Val Prec: {val_metrics['precision']:.2f}%")
        print(f"Train Recall: {train_metrics['recall']:.2f}% | Val Recall: {val_metrics['recall']:.2f}%")
        print(f"Train F1: {train_metrics['f1']:.2f}% | Val F1: {val_metrics['f1']:.2f}%")
        print(f"Learning Rate: {current_lr:.6f}")
        print(f"{'─'*80}\n")
        
        # Learning rate scheduling
        if config['scheduler_type'] == 'plateau':
            scheduler.step(val_loss)
        else:
            scheduler.step()
        
        # Early stopping
        if early_stopping(val_loss, model):
            print(f"\n⚠️  Early stopping triggered at epoch {epoch+1}")
            break
    
    training_time = time.time() - start_time
    best_epoch, best_val_acc = metrics.get_best_epoch()
    
    print(f"\n{'='*80}")
    print(f"✅ Training Completed!")
    print(f"{'='*80}")
    print(f"Total training time: {training_time/60:.2f} minutes")
    print(f"Best validation accuracy: {best_val_acc:.2f}% at epoch {best_epoch+1}")
    print(f"{'='*80}\n")
    
    return model, metrics

print("Training and evaluation functions loaded successfully!")
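
The patience logic in `EarlyStopping` is easier to follow on a concrete loss sequence. The sketch below strips out the weight checkpointing and traces only the counter. The class `LossPatience` and the validation losses are hypothetical, and `patience=3` is chosen to keep the trace short (the notebook uses 10).

```python
class LossPatience:
    """Stop once the loss has failed to improve by min_delta for `patience` epochs in a row."""

    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0

    def step(self, val_loss):
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            # Real improvement: remember it and reset the counter
            self.best_loss = val_loss
            self.counter = 0
        else:
            # No improvement, or an improvement smaller than min_delta
            self.counter += 1
        return self.counter >= self.patience

val_losses = [0.70, 0.55, 0.5495, 0.56, 0.54, 0.545, 0.5399, 0.55]  # hypothetical
stopper = LossPatience(patience=3, min_delta=0.001)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best loss {stopper.best_loss}")
```

Improvements smaller than `min_delta` (0.5495 after 0.55, or 0.5399 after 0.54) count as stagnation, so they advance the counter rather than reset it.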

## 9. Visualization Functions

In [None]:
def plot_training_curves(metrics, model_name, save_path=None):
    """Plot comprehensive training curves for binary classification"""
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle(f'{model_name} - Binary Classification Training Progress', fontsize=16, fontweight='bold')
    
    # Loss curves
    axes[0, 0].plot(metrics.epochs, metrics.train_losses, 'b-', label='Training Loss', linewidth=2)
    axes[0, 0].plot(metrics.epochs, metrics.val_losses, 'r-', label='Validation Loss', linewidth=2)
    axes[0, 0].set_title('Loss Curves')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Accuracy curves
    axes[0, 1].plot(metrics.epochs, metrics.train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
    axes[0, 1].plot(metrics.epochs, metrics.val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
    axes[0, 1].set_title('Accuracy Curves')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Precision curves
    axes[0, 2].plot(metrics.epochs, metrics.train_precisions, 'b-', label='Training Precision', linewidth=2)
    axes[0, 2].plot(metrics.epochs, metrics.val_precisions, 'r-', label='Validation Precision', linewidth=2)
    axes[0, 2].set_title('Precision Curves')
    axes[0, 2].set_xlabel('Epoch')
    axes[0, 2].set_ylabel('Precision (%)')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)
    
    # Recall curves
    axes[1, 0].plot(metrics.epochs, metrics.train_recalls, 'b-', label='Training Recall', linewidth=2)
    axes[1, 0].plot(metrics.epochs, metrics.val_recalls, 'r-', label='Validation Recall', linewidth=2)
    axes[1, 0].set_title('Recall Curves')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Recall (%)')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # F1-Score curves
    axes[1, 1].plot(metrics.epochs, metrics.train_f1s, 'b-', label='Training F1-Score', linewidth=2)
    axes[1, 1].plot(metrics.epochs, metrics.val_f1s, 'r-', label='Validation F1-Score', linewidth=2)
    axes[1, 1].set_title('F1-Score Curves')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('F1-Score (%)')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    # Learning rate curve
    axes[1, 2].plot(metrics.epochs, metrics.learning_rates, 'g-', linewidth=2)
    axes[1, 2].set_title('Learning Rate Schedule')
    axes[1, 2].set_xlabel('Epoch')
    axes[1, 2].set_ylabel('Learning Rate')
    axes[1, 2].set_yscale('log')
    axes[1, 2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Training curves saved to {save_path}")
    
    plt.show()

def plot_confusion_matrix(model, data_loader, class_names, model_name, save_path=None):
    """Plot confusion matrix for binary classification"""
    model.eval()
    all_preds = []
    all_targets = []
    all_probs = []
    
    with torch.no_grad():
        for data, target in tqdm(data_loader, desc="Calculating confusion matrix"):
            data, target = data.to(device), target.to(device)
            output = model(data)
            probs = F.softmax(output, dim=1)
            _, predicted = torch.max(output, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())  # Probability of class 1
    
    # Create confusion matrix
    cm = confusion_matrix(all_targets, all_preds)
    
    # Plot confusion matrix and ROC curve
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    fig.suptitle(f'{model_name} - Binary Classification Performance', fontsize=16, fontweight='bold')
    
    # Confusion matrix
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names, ax=axes[0])
    axes[0].set_title('Confusion Matrix')
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('Actual')
    
    # ROC curve
    fpr, tpr, _ = roc_curve(all_targets, all_probs)
    roc_auc = auc(fpr, tpr)
    
    axes[1].plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    axes[1].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    axes[1].set_xlim([0.0, 1.0])
    axes[1].set_ylim([0.0, 1.05])
    axes[1].set_xlabel('False Positive Rate')
    axes[1].set_ylabel('True Positive Rate')
    axes[1].set_title('ROC Curve')
    axes[1].legend(loc="lower right")
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Confusion matrix saved to {save_path}")
    
    plt.show()
    
    # Print classification report
    print(f"\n{'='*80}")
    print(f"{model_name} - Classification Report")
    print(f"{'='*80}")
    print(classification_report(all_targets, all_preds,
                                labels=list(range(len(class_names))),
                                target_names=class_names, zero_division=0))
    print(f"ROC-AUC Score: {roc_auc:.4f}")
    print(f"{'='*80}\n")

def plot_model_comparison(all_metrics, save_path=None):
    """Plot comparison between different models"""
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Sprint 2 - Binary Classification Model Comparison', fontsize=16, fontweight='bold')
    
    model_names = list(all_metrics.keys())
    
    # Best validation accuracy
    best_accuracies = [np.max(metrics.val_accuracies) for metrics in all_metrics.values()]
    bars = axes[0, 0].bar(model_names, best_accuracies, color=['skyblue', 'lightcoral', 'lightgreen'])
    axes[0, 0].set_title('Best Validation Accuracy')
    axes[0, 0].set_ylabel('Accuracy (%)')
    axes[0, 0].set_ylim(0, 100)
    for bar, acc in zip(bars, best_accuracies):
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                        f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')
    axes[0, 0].grid(True, alpha=0.3)
    
    # Best F1-Score
    best_f1s = [np.max(metrics.val_f1s) for metrics in all_metrics.values()]
    bars = axes[0, 1].bar(model_names, best_f1s, color=['skyblue', 'lightcoral', 'lightgreen'])
    axes[0, 1].set_title('Best Validation F1-Score')
    axes[0, 1].set_ylabel('F1-Score (%)')
    axes[0, 1].set_ylim(0, 100)
    for bar, f1 in zip(bars, best_f1s):
        axes[0, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                        f'{f1:.2f}%', ha='center', va='bottom', fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3)
    
    # Accuracy over time
    for model_name, metrics in all_metrics.items():
        axes[1, 0].plot(metrics.epochs, metrics.val_accuracies, label=model_name, linewidth=2)
    axes[1, 0].set_title('Validation Accuracy Over Time')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Accuracy (%)')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Loss over time
    for model_name, metrics in all_metrics.items():
        axes[1, 1].plot(metrics.epochs, metrics.val_losses, label=model_name, linewidth=2)
    axes[1, 1].set_title('Validation Loss Over Time')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Loss')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Model comparison saved to {save_path}")
    
    plt.show()

print("Visualization functions loaded successfully!")

## 10. Training Pipeline

In [None]:
# Main Training Pipeline
print("\n" + "="*80)
print("🚀 Starting Sprint 2 Binary Classification Training Pipeline")
print("="*80 + "\n")

# Define binary models
models_dict = {
    "BinaryCNN_Light": BinaryCNN_Light(),
    "BinaryCNN_Deep": BinaryCNN_Deep(),
    "BinaryCNN_Efficient": BinaryCNN_Efficient(),
}

# Print model information
print("📊 Model Information:")
print("─"*80)
for name, model in models_dict.items():
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{name}:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
print("─"*80 + "\n")

# Train all models
results = {}
all_metrics = {}

for name, model in models_dict.items():
    try:
        # Train model
        trained_model, metrics = train_and_evaluate_model(
            model, name, train_loader, val_loader, CONFIG
        )
        
        # Store results
        results[name] = trained_model
        all_metrics[name] = metrics
        
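        # Save model (create the output directory first in case it does not exist yet)
        os.makedirs(CONFIG['save_dir'], exist_ok=True)
        model_path = os.path.join(CONFIG['save_dir'], f'{name}_best.pth')
        torch.save(trained_model.state_dict(), model_path)
        print(f"💾 Model saved to {model_path}\n")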
        
        # Plot training curves
        plot_training_curves(metrics, name, 
                           os.path.join(CONFIG['save_dir'], f'{name}_training_curves.png'))
        
        # Plot confusion matrix
        plot_confusion_matrix(trained_model, val_loader, CONFIG['class_names'], name,
                            os.path.join(CONFIG['save_dir'], f'{name}_confusion_matrix.png'))
        
    except Exception as e:
        print(f"❌ Error training {name}: {e}")
        import traceback
        traceback.print_exc()
        continue

print("\n" + "="*80)
print("🎉 Training Pipeline Completed!")
print("="*80)
print(f"Trained models: {list(results.keys())}")
print(f"Results saved to: {CONFIG['save_dir']}")
print("="*80 + "\n")

# Plot model comparison
if all_metrics:
    plot_model_comparison(all_metrics, 
                        os.path.join(CONFIG['save_dir'], 'model_comparison.png'))

# Final results summary
print("\n" + "="*80)
print("📊 Final Results Summary")
print("="*80)
for name, metrics in all_metrics.items():
    best_epoch, best_acc = metrics.get_best_epoch()
    best_f1_idx = np.argmax(metrics.val_f1s)
    best_f1 = metrics.val_f1s[best_f1_idx]
    print(f"\n{name}:")
    print(f"  Best Accuracy: {best_acc:.2f}% at epoch {best_epoch+1}")
    print(f"  Best F1-Score: {best_f1:.2f}% at epoch {best_f1_idx+1}")
print("="*80 + "\n")

## 11. Summary and Conclusions

### Characteristics of the Binary Classification:

1. **Problem Simplification**:
   - Reduces 5 classes to 2 (Healthy vs. Unhealthy)
   - Accuracy is expected to increase because of the simplification
   - Better suited to practical disease-detection applications

2. **Specific Metrics**:
   - **Precision**: Important for minimizing false positives
   - **Recall**: Important for minimizing false negatives
   - **F1-Score**: Balance between Precision and Recall (their harmonic mean)
   - **ROC-AUC**: The model's ability to discriminate between the two classes

3. **Optimized Architectures**:
   - **BinaryCNN_Light**: Fast and efficient for real-time inference
   - **BinaryCNN_Deep**: Higher capacity, with residual connections
   - **BinaryCNN_Efficient**: Balance between speed and accuracy

4. **Advanced Techniques**:
   - Data Augmentation (Mixup, CutMix)
   - Learning Rate Scheduling
   - Early Stopping
   - Mixed Precision Training
   - Detailed progress visualization
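
To make the metric definitions in item 2 concrete, here is a minimal sketch that derives them from the four cells of a binary confusion matrix, treating class 1 (Unhealthy) as the positive class. The function name `binary_metrics` and the counts are hypothetical and only illustrate the arithmetic; they are not results from this notebook.

```python
def binary_metrics(tn, fp, fn, tp):
    """Compute accuracy, precision, recall and F1 for the positive class.

    tn/fp/fn/tp are the confusion-matrix counts with class 1 as positive.
    Each ratio falls back to 0.0 when its denominator is zero.
    """
    total = tn + fp + fn + tp
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0   # penalizes false positives
    recall = tp / (tp + fn) if (tp + fn) else 0.0      # penalizes false negatives
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0  # harmonic mean
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical counts, chosen only for illustration
example = binary_metrics(tn=90, fp=10, fn=20, tp=80)
for metric_name, value in example.items():
    print(f"{metric_name}: {value:.4f}")
```

With scikit-learn's `confusion_matrix` and the label order `[0, 1]`, rows are actual classes and columns are predicted classes, so `cm.ravel()` yields the counts in the order `tn, fp, fn, tp`.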

### Advantages of Binary Classification:
- ✅ Higher expected accuracy
- ✅ Simpler to interpret
- ✅ Less data needed for training
- ✅ Ideal for initial screening
- ✅ More efficient deployment
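
For the screening use case, one option (an assumption, not something this notebook implements) is to lower the decision threshold on the predicted probability of class 1 so that fewer diseased leaves are missed, at the cost of more false alarms. The sketch below uses a hypothetical helper, `precision_recall_at`, and made-up labels and probabilities to show the trade-off.

```python
def precision_recall_at(targets, probs, threshold):
    """Return (precision, recall) for class 1 when predicting 1 if prob >= threshold."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(1 for t, y in zip(targets, preds) if t == 1 and y == 1)
    fp = sum(1 for t, y in zip(targets, preds) if t == 0 and y == 1)
    fn = sum(1 for t, y in zip(targets, preds) if t == 1 and y == 0)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Made-up labels (1 = Unhealthy) and class-1 probabilities, for illustration only
targets = [0, 0, 0, 0, 1, 1, 1, 1]
probs = [0.10, 0.32, 0.45, 0.60, 0.35, 0.55, 0.80, 0.95]

for threshold in (0.5, 0.3):
    precision, recall = precision_recall_at(targets, probs, threshold)
    print(f"threshold={threshold:.1f}  precision={precision:.2f}  recall={recall:.2f}")
```

Taking the argmax over a two-class softmax output, as `plot_confusion_matrix` does, corresponds to a threshold of 0.5 on the class-1 probability. Any other threshold should be chosen on validation data, for example by reading it off the ROC curve.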

### Next Steps:
- Detailed analysis of the metrics
- Tests with real-world images
- Hyperparameter optimization
- Preparation for production deployment
- Integration with a monitoring system