---
# SUMMARY OF IMPLEMENTED IMPROVEMENTS

## Experiment 1 (Baseline)
- **Architecture**: U-Net with a ResNet50 encoder pretrained on ImageNet
- **Data augmentation**: Basic (flip, rotation, brightness/contrast)
- **Loss**: BCE + Dice
- **Scheduler**: CosineAnnealing

## Experiment 2 (Improvement 1)
**Focus: Preprocessing, Data Augmentation, and Regularization**

| Improvement | Description |
|-------------|-------------|
| **CLAHE** | Contrast Limited Adaptive Histogram Equalization applied to the luminance channel to enhance structures in fundus images |
| **Medical data augmentation** | ElasticTransform, GridDistortion, and OpticalDistortion to simulate anatomical variation |
| **CoarseDropout** | Simulates partial occlusions as a regularizer |
| **Deep Supervision** | Multiple outputs at different scales to improve gradient flow |
| **TTA** | Test-time augmentation with 7 transformations for more robust predictions |
| **Scheduler** | CosineAnnealingWarmRestarts, which periodically resets the learning rate, intended to improve convergence |

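The TTA row in the table above can be illustrated with a minimal sketch. It assumes only horizontal and vertical flips (the notebook's `TestTimeAugmentation` class applies 7 transformations, which are not reproduced here), and `toy_predict` is a hypothetical stand-in for the model. Images are plain nested lists so the example runs without dependencies.

```python
def hflip(img):
    """Flip a 2D grid (list of rows) left to right."""
    return [row[::-1] for row in img]

def vflip(img):
    """Flip a 2D grid top to bottom."""
    return img[::-1]

def identity(img):
    return [row[:] for row in img]

def tta_predict(predict, img):
    """Average predictions over flips, undoing each flip on the output."""
    # Each flip is its own inverse, so the same function maps the
    # prediction back to the original orientation.
    transforms = [identity, hflip, vflip]
    height, width = len(img), len(img[0])
    total = [[0.0] * width for _ in range(height)]
    for t in transforms:
        pred = t(predict(t(img)))  # augment, predict, de-augment
        for r in range(height):
            for c in range(width):
                total[r][c] += pred[r][c]
    return [[v / len(transforms) for v in row] for row in total]

# Hypothetical "model": a per-pixel function, so it commutes with flips.
def toy_predict(img):
    return [[min(1.0, 0.5 * v) for v in row] for row in img]

image = [[0.0, 1.0], [2.0, 0.0]]
averaged = tta_predict(toy_predict, image)
```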
## Experiment 3 (Improvement 2 - Architecture)
**Focus: Changes to the Neural Network Topology**

| Modification | Description |
|--------------|-------------|
| **Attention Gates** | Attention mechanisms on the skip connections that let the model focus on relevant regions while suppressing irrelevant responses |
| **SE Blocks** | Squeeze-and-Excitation blocks for adaptive recalibration of the feature channels |
| **ASPP** | Atrous Spatial Pyramid Pooling at the bottleneck to capture multi-scale context |
| **Focal Loss** | Added to the loss function to handle class imbalance |
| **Discriminative LR** | Separate learning rates for the encoder (lower) and the decoder (higher) |
| **OneCycleLR** | More aggressive schedule (short warm-up to a peak learning rate, then cosine annealing), intended to improve generalization |

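To make the Focal Loss row concrete, here is a small sketch of the standard binary focal loss for a single pixel. It follows the usual textbook definition with the same `alpha=0.25` and `gamma=2.0` used in this notebook; the library implementation may differ in details such as reduction or normalization, and the probabilities below are made up for illustration.

```python
import math

def binary_focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Focal loss for one pixel; p is the predicted foreground probability."""
    p_t = p if y == 1 else 1.0 - p              # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

def binary_cross_entropy(p, y):
    """Plain BCE for one pixel, for comparison."""
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)

easy = binary_focal_loss(0.95, 1)   # confident, correct foreground pixel
hard = binary_focal_loss(0.30, 1)   # poorly classified foreground pixel
print(f"easy={easy:.6f}  hard={hard:.6f}")
```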
---
**Note**: Experiment 3 substantially modifies the network architecture rather than merely increasing its width or depth, which satisfies the requirement to optimize the network topology.

In [None]:
# Side-by-side visual comparison of the predictions
def compare_predictions(dataset, idx=0):
    """Compare the predictions of the 3 models side by side."""
    img, mask = dataset[idx]
    img_tensor = img.unsqueeze(0).to(device)
    
    # Load the best checkpoint of each model
    # (map_location lets checkpoints saved on GPU load on a CPU-only machine)
    model.load_state_dict(torch.load('best_optic_disc_model.pth', map_location=device))
    model.eval()
    exp2_model.load_state_dict(torch.load('best_exp2_model.pth', map_location=device))
    exp2_model.eval()
    exp3_model.load_state_dict(torch.load('best_exp3_attention_unet.pth', map_location=device))
    exp3_model.eval()
    
    with torch.no_grad():
        pred1 = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()
        pred2 = torch.sigmoid(exp2_model(img_tensor)).cpu().squeeze().numpy()
        pred3 = torch.sigmoid(exp3_model(img_tensor)).cpu().squeeze().numpy()
    
    # Undo the ImageNet normalization so the image can be displayed
    img_np = img.numpy().transpose(1, 2, 0)
    img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img_np = np.clip(img_np, 0, 1)
    
    mask_np = mask.squeeze().numpy()
    
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    
    # Row 1: image, ground truth, and binarized predictions
    axes[0, 0].imshow(img_np)
    axes[0, 0].set_title('Original Image')
    
    axes[0, 1].imshow(mask_np, cmap='gray')
    axes[0, 1].set_title('Ground Truth')
    
    axes[0, 2].imshow(pred1 > 0.5, cmap='gray')
    axes[0, 2].set_title('Exp1: Baseline')
    
    axes[0, 3].imshow(pred2 > 0.5, cmap='gray')
    axes[0, 3].set_title('Exp2: CLAHE+DS')
    
    # Row 2: overlays
    overlay_gt = img_np.copy()
    overlay_gt[mask_np > 0.5] = overlay_gt[mask_np > 0.5] * 0.5 + np.array([0, 1, 0]) * 0.5
    axes[1, 0].imshow(overlay_gt)
    axes[1, 0].set_title('Overlay GT')
    
    overlay1 = img_np.copy()
    overlay1[pred1 > 0.5] = overlay1[pred1 > 0.5] * 0.5 + np.array([0, 0, 1]) * 0.5
    axes[1, 1].imshow(overlay1)
    axes[1, 1].set_title('Overlay Exp1')
    
    overlay2 = img_np.copy()
    overlay2[pred2 > 0.5] = overlay2[pred2 > 0.5] * 0.5 + np.array([1, 0.5, 0]) * 0.5
    axes[1, 2].imshow(overlay2)
    axes[1, 2].set_title('Overlay Exp2')
    
    overlay3 = img_np.copy()
    overlay3[pred3 > 0.5] = overlay3[pred3 > 0.5] * 0.5 + np.array([1, 0, 1]) * 0.5
    axes[1, 3].imshow(overlay3)
    axes[1, 3].set_title('Overlay Exp3: Attention')
    
    for ax in axes.flat:
        ax.axis('off')
    
    plt.suptitle('Comparison of the Predictions from the 3 Experiments', fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize a few samples
for idx in [0, 1, 2]:
    compare_predictions(exp3_val_dataset, idx)

In [None]:
# Comparative plot of the learning curves
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Loss
axes[0, 0].plot(history['val_loss'], label='Exp1 (Baseline)', linewidth=2)
axes[0, 0].plot(exp2_history['val_loss'], label='Exp2 (CLAHE+DS)', linewidth=2)
axes[0, 0].plot(exp3_history['val_loss'], label='Exp3 (Attention)', linewidth=2)
axes[0, 0].set_title('Validation Loss')
axes[0, 0].legend()
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].grid(True, alpha=0.3)

# IoU
axes[0, 1].plot(history['val_iou'], label='Exp1 (Baseline)', linewidth=2)
axes[0, 1].plot(exp2_history['val_iou'], label='Exp2 (CLAHE+DS)', linewidth=2)
axes[0, 1].plot(exp3_history['val_iou'], label='Exp3 (Attention)', linewidth=2)
axes[0, 1].set_title('Validation IoU')
axes[0, 1].legend()
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].grid(True, alpha=0.3)

# Dice
axes[0, 2].plot(history['val_dice'], label='Exp1 (Baseline)', linewidth=2)
axes[0, 2].plot(exp2_history['val_dice'], label='Exp2 (CLAHE+DS)', linewidth=2)
axes[0, 2].plot(exp3_history['val_dice'], label='Exp3 (Attention)', linewidth=2)
axes[0, 2].set_title('Validation Dice')
axes[0, 2].legend()
axes[0, 2].set_xlabel('Epoch')
axes[0, 2].grid(True, alpha=0.3)

# Comparative bar plot of the mean scores
experiments = ['Exp1\n(Baseline)', 'Exp2\n(CLAHE+DS)', 'Exp2\n(+TTA)', 'Exp3\n(Attention)', 'Exp3\n(+TTA)']
dice_scores = [np.mean(all_dice), np.mean(all_dice_no_tta), np.mean(all_dice_tta), 
               np.mean(exp3_dice_no_tta), np.mean(exp3_dice_tta)]
iou_scores = [np.mean(all_iou), np.mean(all_iou_no_tta), np.mean(all_iou_tta),
              np.mean(exp3_iou_no_tta), np.mean(exp3_iou_tta)]

x = np.arange(len(experiments))
width = 0.35

axes[1, 0].bar(x - width/2, dice_scores, width, label='Dice', color='steelblue')
axes[1, 0].bar(x + width/2, iou_scores, width, label='IoU', color='coral')
axes[1, 0].set_ylabel('Score')
axes[1, 0].set_title('Final Comparison - Metrics')
axes[1, 0].set_xticks(x)
axes[1, 0].set_xticklabels(experiments)
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3, axis='y')
axes[1, 0].set_ylim([0.8, 1.0])

# Dice box plot
dice_data = [all_dice, all_dice_no_tta, all_dice_tta, exp3_dice_no_tta, exp3_dice_tta]
bp = axes[1, 1].boxplot(dice_data)
# Tick labels are set separately because boxplot's `labels` keyword was renamed in newer Matplotlib releases
axes[1, 1].set_xticklabels(['Exp1', 'Exp2', 'Exp2+TTA', 'Exp3', 'Exp3+TTA'])
axes[1, 1].set_ylabel('Dice Score')
axes[1, 1].set_title('Dice Score Distribution')
axes[1, 1].grid(True, alpha=0.3)

# IoU box plot
iou_data = [all_iou, all_iou_no_tta, all_iou_tta, exp3_iou_no_tta, exp3_iou_tta]
bp = axes[1, 2].boxplot(iou_data)
# Tick labels are set separately because boxplot's `labels` keyword was renamed in newer Matplotlib releases
axes[1, 2].set_xticklabels(['Exp1', 'Exp2', 'Exp2+TTA', 'Exp3', 'Exp3+TTA'])
axes[1, 2].set_ylabel('IoU')
axes[1, 2].set_title('IoU Distribution')
axes[1, 2].grid(True, alpha=0.3)

plt.suptitle('Comparison of the 3 Experiments', fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
print("="*80)
print("                    FINAL COMPARISON OF THE EXPERIMENTS")
print("="*80)

print("\n" + "-"*80)
print("Experiment 1 (Baseline): U-Net + ResNet50")
print("-"*80)
print(f"  IoU:  {np.mean(all_iou):.4f} +/- {np.std(all_iou):.4f}")
print(f"  Dice: {np.mean(all_dice):.4f} +/- {np.std(all_dice):.4f}")

print("\n" + "-"*80)
print("Experiment 2: CLAHE + Advanced Data Augmentation + Deep Supervision")
print("-"*80)
print("  Without TTA:")
print(f"    IoU:  {np.mean(all_iou_no_tta):.4f} +/- {np.std(all_iou_no_tta):.4f}")
print(f"    Dice: {np.mean(all_dice_no_tta):.4f} +/- {np.std(all_dice_no_tta):.4f}")
print("  With TTA:")
print(f"    IoU:  {np.mean(all_iou_tta):.4f} +/- {np.std(all_iou_tta):.4f}")
print(f"    Dice: {np.mean(all_dice_tta):.4f} +/- {np.std(all_dice_tta):.4f}")

print("\n" + "-"*80)
print("Experiment 3: Attention U-Net (ASPP + Attention Gates + SE Blocks)")
print("-"*80)
print("  Without TTA:")
print(f"    IoU:  {np.mean(exp3_iou_no_tta):.4f} +/- {np.std(exp3_iou_no_tta):.4f}")
print(f"    Dice: {np.mean(exp3_dice_no_tta):.4f} +/- {np.std(exp3_dice_no_tta):.4f}")
print("  With TTA:")
print(f"    IoU:  {np.mean(exp3_iou_tta):.4f} +/- {np.std(exp3_iou_tta):.4f}")
print(f"    Dice: {np.mean(exp3_dice_tta):.4f} +/- {np.std(exp3_dice_tta):.4f}")

print("\n" + "="*80)
print("                              IMPROVEMENTS")
print("="*80)
# Exp2 and Exp3 are represented by their TTA scores.
# The differences are absolute, so they are reported in percentage points (p.p.).
baseline_dice = np.mean(all_dice)
exp2_dice = np.mean(all_dice_tta)
exp3_dice = np.mean(exp3_dice_tta)

print(f"\nImprovement Exp2 (TTA) vs Baseline: {(exp2_dice - baseline_dice)*100:+.2f} p.p. Dice")
print(f"Improvement Exp3 (TTA) vs Baseline: {(exp3_dice - baseline_dice)*100:+.2f} p.p. Dice")
print(f"Improvement Exp3 (TTA) vs Exp2 (TTA): {(exp3_dice - exp2_dice)*100:+.2f} p.p. Dice")

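The summary cell reports differences between mean Dice scores multiplied by 100. Such a difference is an absolute gain in percentage points, which is not the same as a relative gain in percent. The sketch below shows both quantities side by side; the scores are hypothetical and serve only as an illustration.

```python
def improvement(baseline, new):
    """Return (absolute gain in percentage points, relative gain in percent)."""
    points = (new - baseline) * 100.0
    relative = (new - baseline) / baseline * 100.0
    return points, relative

# Hypothetical Dice scores, for illustration only.
points, relative = improvement(baseline=0.90, new=0.92)
print(f"{points:+.2f} p.p. | {relative:+.2f}% relative")
```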
---
# FINAL COMPARISON OF THE EXPERIMENTS

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

axes[0].plot(exp3_history['train_loss'], label='Train')
axes[0].plot(exp3_history['val_loss'], label='Validation')
axes[0].set_title('Exp3 - Loss')
axes[0].legend()

axes[1].plot(exp3_history['train_iou'], label='Train')
axes[1].plot(exp3_history['val_iou'], label='Validation')
axes[1].set_title('Exp3 - IoU')
axes[1].legend()

axes[2].plot(exp3_history['train_dice'], label='Train')
axes[2].plot(exp3_history['val_dice'], label='Validation')
axes[2].set_title('Exp3 - Dice Score')
axes[2].legend()

for ax in axes:
    ax.set_xlabel('Epoch')
    ax.grid(True, alpha=0.3)

plt.suptitle('Experiment 3: Attention U-Net', fontsize=14)
plt.tight_layout()
plt.show()

## Exp3.7 - Training Curves for Experiment 3

In [None]:
# Load the best Exp3 model
exp3_model.load_state_dict(torch.load('best_exp3_attention_unet.pth', map_location=device))
exp3_model.eval()

# Build the TTA wrapper for Exp3
tta_exp3 = TestTimeAugmentation(exp3_model, device)

# Per-image metrics
exp3_iou_no_tta = []
exp3_dice_no_tta = []
exp3_iou_tta = []
exp3_dice_tta = []

print("Evaluating Experiment 3 (with and without TTA)...")

with torch.no_grad():
    for images, masks in tqdm(exp3_val_loader, desc='Evaluating Exp3'):
        images, masks = images.to(device), masks.to(device)
        
        for i in range(images.shape[0]):
            img = images[i:i+1]
            mask = masks[i:i+1]
            
            # Without TTA
            pred = torch.sigmoid(exp3_model(img))
            pred_bin = (pred > 0.5).float()
            
            intersection = (pred_bin * mask).sum()
            union = pred_bin.sum() + mask.sum() - intersection
            iou = (intersection + 1e-6) / (union + 1e-6)
            dice = (2 * intersection + 1e-6) / (pred_bin.sum() + mask.sum() + 1e-6)
            
            exp3_iou_no_tta.append(iou.item())
            exp3_dice_no_tta.append(dice.item())
            
            # With TTA
            pred_tta = tta_exp3(img)
            pred_bin_tta = (pred_tta > 0.5).float()
            
            intersection = (pred_bin_tta * mask).sum()
            union = pred_bin_tta.sum() + mask.sum() - intersection
            iou_tta = (intersection + 1e-6) / (union + 1e-6)
            dice_tta = (2 * intersection + 1e-6) / (pred_bin_tta.sum() + mask.sum() + 1e-6)
            
            exp3_iou_tta.append(iou_tta.item())
            exp3_dice_tta.append(dice_tta.item())

print('\n' + '='*60)
print('=== EXPERIMENT 3 RESULTS (Attention U-Net) ===')
print('='*60)
print('\nWithout TTA:')
print(f'  IoU  - Mean: {np.mean(exp3_iou_no_tta):.4f} | Std: {np.std(exp3_iou_no_tta):.4f}')
print(f'  Dice - Mean: {np.mean(exp3_dice_no_tta):.4f} | Std: {np.std(exp3_dice_no_tta):.4f}')
print('\nWith TTA:')
print(f'  IoU  - Mean: {np.mean(exp3_iou_tta):.4f} | Std: {np.std(exp3_iou_tta):.4f}')
print(f'  Dice - Mean: {np.mean(exp3_dice_tta):.4f} | Std: {np.std(exp3_dice_tta):.4f}')

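The per-image metrics computed in the evaluation cell can be checked by hand on a tiny example. This sketch uses flat Python lists in place of tensors and the same `1e-6` smoothing term; the masks are made up for illustration. It also shows that the smoothing makes both metrics equal 1 when prediction and ground truth are both empty.

```python
def iou_dice(pred, target, eps=1e-6):
    """IoU and Dice for two flat binary masks (lists of 0/1)."""
    intersection = sum(p * t for p, t in zip(pred, target))
    pred_sum, target_sum = sum(pred), sum(target)
    union = pred_sum + target_sum - intersection
    iou = (intersection + eps) / (union + eps)
    dice = (2 * intersection + eps) / (pred_sum + target_sum + eps)
    return iou, dice

pred   = [1, 1, 1, 0, 0, 0]
target = [0, 1, 1, 1, 0, 0]
iou, dice = iou_dice(pred, target)                 # 2 shared pixels, 4 in the union
empty_iou, empty_dice = iou_dice([0, 0], [0, 0])   # both masks empty
print(f"IoU={iou:.4f}  Dice={dice:.4f}")
```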
## Exp3.6 - Final Evaluation of Experiment 3

In [None]:
# Training loop for Experiment 3
exp3_history = {'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': [], 
                'train_dice': [], 'val_dice': []}
exp3_best_dice = 0

print("="*60)
print("STARTING TRAINING - EXPERIMENT 3 (Attention U-Net)")
print("="*60)

for epoch in range(EXP3_NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{EXP3_NUM_EPOCHS}')
    
    train_loss, train_iou, train_dice = train_epoch_exp3(
        exp3_model, exp3_train_loader, exp3_criterion, exp3_optimizer, exp3_scheduler
    )
    val_loss, val_iou, val_dice = validate_exp3(
        exp3_model, exp3_val_loader, exp3_criterion
    )
    
    exp3_history['train_loss'].append(train_loss)
    exp3_history['val_loss'].append(val_loss)
    exp3_history['train_iou'].append(train_iou)
    exp3_history['val_iou'].append(val_iou)
    exp3_history['train_dice'].append(train_dice)
    exp3_history['val_dice'].append(val_dice)
    
    current_lr = exp3_optimizer.param_groups[1]['lr']
    print(f'Train - Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f}')
    print(f'Val   - Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f}')
    print(f'LR (decoder): {current_lr:.2e}')
    
    if val_dice > exp3_best_dice:
        exp3_best_dice = val_dice
        torch.save(exp3_model.state_dict(), 'best_exp3_attention_unet.pth')
        print(f'*** Exp3 model saved! Dice: {exp3_best_dice:.4f} ***')

print("\n" + "="*60)
print(f"EXPERIMENT 3 COMPLETE - Best Dice: {exp3_best_dice:.4f}")
print("="*60)

In [None]:
def train_epoch_exp3(model, loader, criterion, optimizer, scheduler):
    """Train Experiment 3 for one epoch, stepping OneCycleLR after every batch."""
    model.train()
    total_loss = 0
    total_iou = 0
    total_dice = 0
    
    for images, masks in tqdm(loader, desc='Train Exp3'):
        images, masks = images.to(device), masks.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        scheduler.step()  # OneCycleLR is stepped once per batch
        
        total_loss += loss.item()
        iou, dice = calc_metrics(outputs, masks)
        total_iou += iou
        total_dice += dice
    
    n = len(loader)
    return total_loss/n, total_iou/n, total_dice/n

@torch.no_grad()
def validate_exp3(model, loader, criterion):
    """Validation pass for Experiment 3."""
    model.eval()
    total_loss = 0
    total_iou = 0
    total_dice = 0
    
    for images, masks in tqdm(loader, desc='Val Exp3'):
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        total_loss += loss.item()
        iou, dice = calc_metrics(outputs, masks)
        total_iou += iou
        total_dice += dice
    
    n = len(loader)
    return total_loss/n, total_iou/n, total_dice/n

print("Exp3 training functions defined!")

In [None]:
# Loss and optimizer for Exp3
exp3_dice_loss = smp.losses.DiceLoss(mode='binary')
exp3_bce_loss = smp.losses.SoftBCEWithLogitsLoss()
# Add Focal Loss to handle class imbalance
exp3_focal_loss = smp.losses.FocalLoss(mode='binary', alpha=0.25, gamma=2.0)

def exp3_criterion(pred, target):
    """Combined loss: 0.4 * BCE + 0.4 * Dice + 0.2 * Focal."""
    return 0.4 * exp3_bce_loss(pred, target) + 0.4 * exp3_dice_loss(pred, target) + 0.2 * exp3_focal_loss(pred, target)

# Optimizer with discriminative learning rates (both groups share the same weight decay)
encoder_params = list(exp3_model.encoder.parameters())
decoder_params = [p for n, p in exp3_model.named_parameters() if 'encoder' not in n]

# OneCycleLR replaces these starting values with max_lr / div_factor for each group
exp3_optimizer = optim.AdamW([
    {'params': encoder_params, 'lr': EXP3_LEARNING_RATE * 0.1},  # Encoder: lower LR
    {'params': decoder_params, 'lr': EXP3_LEARNING_RATE}         # Decoder: base LR
], weight_decay=1e-4)

exp3_scheduler = optim.lr_scheduler.OneCycleLR(
    exp3_optimizer,
    max_lr=[EXP3_LEARNING_RATE * 0.1, EXP3_LEARNING_RATE],
    epochs=EXP3_NUM_EPOCHS,
    steps_per_epoch=len(exp3_train_loader),
    pct_start=0.1,
    anneal_strategy='cos'
)

print("Exp3 optimizer and loss configured!")

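The shape of the one-cycle schedule configured above can be sketched in plain Python. This is an approximation of the two-phase cosine schedule documented for `torch.optim.lr_scheduler.OneCycleLR`, assuming its default `div_factor=25` and `final_div_factor=1e4`, which the notebook does not override. The value of `steps_per_epoch` is hypothetical, since the real one depends on `len(exp3_train_loader)`.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.1,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate at a given batch step of a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1   # last step of the warm-up phase

    def cos_anneal(start, end, pct):
        # Moves from `start` (pct=0) to `end` (pct=1) along half a cosine wave.
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return cos_anneal(max_lr, min_lr, pct)

steps_per_epoch = 100                  # hypothetical
total_steps = 50 * steps_per_epoch     # EXP3_NUM_EPOCHS = 50
decoder_peak, encoder_peak = 1e-4, 1e-5
for step in (0, 499, 2500, total_steps - 1):
    print(step,
          f"decoder={one_cycle_lr(step, total_steps, decoder_peak):.2e}",
          f"encoder={one_cycle_lr(step, total_steps, encoder_peak):.2e}")
```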
## Exp3.5 - Training Experiment 3

In [None]:
# Reuse the Exp2 transforms (with CLAHE) for a fair comparison
exp3_train_dataset = OpticDiscDataset(train_pairs, get_exp2_train_transforms())
exp3_val_dataset = OpticDiscDataset(val_pairs, get_exp2_val_transforms())

exp3_train_loader = DataLoader(exp3_train_dataset, batch_size=EXP3_BATCH_SIZE, 
                                shuffle=True, num_workers=2, pin_memory=True)
exp3_val_loader = DataLoader(exp3_val_dataset, batch_size=EXP3_BATCH_SIZE, 
                              shuffle=False, num_workers=2, pin_memory=True)

# Build the Attention U-Net model
exp3_model = AttentionUNet(
    encoder_name=EXP3_ENCODER,
    encoder_weights=EXP3_ENCODER_WEIGHTS,
    in_channels=3,
    classes=1
).to(device)

# Count trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print('Exp3 - Model: Attention U-Net')
print(f'Exp3 - Trainable parameters: {count_parameters(exp3_model):,}')

## Exp3.4 - Dataset and Model for Experiment 3

In [None]:
class AttentionUNet(nn.Module):
    """
    Attention U-Net with:
    - Pretrained ResNet50 encoder
    - Attention Gates on the skip connections
    - SE blocks in the decoder
    - ASPP at the bottleneck
    """
    def __init__(self, encoder_name='resnet50', encoder_weights='imagenet', 
                 in_channels=3, classes=1):
        super().__init__()
        
        # Pretrained encoder
        self.encoder = smp.encoders.get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=5,
            weights=encoder_weights
        )
        
        # ResNet50 encoder channels: [3, 64, 256, 512, 1024, 2048]
        encoder_channels = self.encoder.out_channels
        
        # ASPP at the bottleneck
        self.aspp = ASPP(encoder_channels[-1], encoder_channels[-1] // 2)
        
        # Decoder channels
        decoder_channels = [256, 128, 64, 32, 16]
        
        # Attention Gates. F_g must equal the channel count of the decoder tensor
        # that reaches each gate in forward(): the ASPP output for the first gate,
        # then the output of the previous decoder block.
        self.attention_gates = nn.ModuleList([
            AttentionGate(encoder_channels[-1] // 2, encoder_channels[-2], decoder_channels[0] // 2),  # g: 1024, skip: 1024
            AttentionGate(decoder_channels[0], encoder_channels[-3], decoder_channels[1] // 2),  # g: 256, skip: 512
            AttentionGate(decoder_channels[1], encoder_channels[-4], decoder_channels[2] // 2),  # g: 128, skip: 256
            AttentionGate(decoder_channels[2], encoder_channels[-5], decoder_channels[3] // 2),  # g: 64, skip: 64
        ])
        
        # Decoder blocks with SE
        self.decoder_blocks = nn.ModuleList()
        
        # Block 1: ASPP output + deepest skip
        in_ch = encoder_channels[-1] // 2 + encoder_channels[-2]  # aspp + skip
        self.decoder_blocks.append(self._make_decoder_block(in_ch, decoder_channels[0]))
        
        # Blocks 2-4: previous decoder output + skip
        for i in range(1, 4):
            in_ch = decoder_channels[i-1] + encoder_channels[-2-i]
            self.decoder_blocks.append(self._make_decoder_block(in_ch, decoder_channels[i]))
        
        # Block 5 (no skip)
        self.decoder_blocks.append(self._make_decoder_block(decoder_channels[3], decoder_channels[4]))
        
        # Segmentation head
        self.segmentation_head = nn.Conv2d(decoder_channels[-1], classes, kernel_size=1)
    
    def _make_decoder_block(self, in_channels, out_channels):
        """Build a decoder block that ends with an SE block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            SqueezeExcitation(out_channels, reduction=16)
        )
    
    def forward(self, x):
        # Encoder
        features = self.encoder(x)
        # features: [input, stage1, stage2, stage3, stage4, stage5]
        # Channels for ResNet50: [3, 64, 256, 512, 1024, 2048]
        
        # ASPP at the bottleneck
        x = self.aspp(features[-1])
        
        # Decoder with Attention Gates
        skips = features[:-1][::-1]  # Reverse the skips so the deepest comes first
        
        for i in range(4):
            # Upsample
            x = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            
            # Attention Gate on the skip connection
            skip = self.attention_gates[i](x, skips[i])
            
            # Concatenate
            x = torch.cat([x, skip], dim=1)
            
            # Decoder block
            x = self.decoder_blocks[i](x)
        
        # Final upsample (no skip)
        x = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = self.decoder_blocks[4](x)
        
        # Segmentation head
        x = self.segmentation_head(x)
        
        return x

print("Attention U-Net defined!")

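Tracing channel counts through `forward` by hand is a cheap way to confirm that every convolution receives the width it was built for. The sketch below does that bookkeeping in plain Python, assuming the ResNet50 channel list quoted in the code comments and an ASPP output of 1024 channels (half of 2048). The helper name `trace_decoder_channels` is hypothetical. The gating column is the number of channels each attention gate receives as `g`, and the concat column is the input width of each decoder block.

```python
def trace_decoder_channels(encoder_channels, decoder_channels, aspp_out):
    """List (gating, skip, concatenated, output) channels for each gated stage."""
    skips = encoder_channels[:-1][::-1]   # deepest skip first, as in forward()
    x_channels = aspp_out                 # channels of x entering the loop
    stages = []
    for i in range(4):
        gating = x_channels               # upsampling leaves the channel count unchanged
        skip = skips[i]                   # the gate preserves the skip's channel count
        concat = gating + skip            # torch.cat([x, skip], dim=1)
        x_channels = decoder_channels[i]  # output of the decoder block
        stages.append((gating, skip, concat, x_channels))
    return stages

resnet50_channels = [3, 64, 256, 512, 1024, 2048]
stages = trace_decoder_channels(resnet50_channels, [256, 128, 64, 32, 16], aspp_out=1024)
for gating, skip, concat, out in stages:
    print(f"g={gating:4d}  skip={skip:4d}  concat={concat:4d}  out={out:3d}")
```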
## Exp3.3 - Complete Attention U-Net

In [None]:
class AttentionGate(nn.Module):
    """
    Attention Gate: lets the model focus on the relevant regions
    of the skip connections while suppressing irrelevant responses.
    
    Reference: "Attention U-Net: Learning Where to Look for the Pancreas"
    """
    def __init__(self, F_g, F_l, F_int):
        """
        F_g: number of channels in the gating signal (from the decoder)
        F_l: number of channels in the skip connection (from the encoder)
        F_int: number of intermediate channels
        """
        super().__init__()
        
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, g, x):
        """
        g: gating signal from the decoder (lower resolution)
        x: skip connection from the encoder (higher resolution)
        """
        # Resize g to the spatial size of x if they differ
        if g.shape[2:] != x.shape[2:]:
            g = nn.functional.interpolate(g, size=x.shape[2:], mode='bilinear', align_corners=False)
        
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        
        return x * psi


class SqueezeExcitation(nn.Module):
    """
    Squeeze-and-Excitation Block: adaptive channel recalibration.
    Learns the relative importance of each feature channel.
    
    Reference: "Squeeze-and-Excitation Networks"
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.shape
        
        # Squeeze: Global Average Pooling
        y = self.squeeze(x).view(b, c)
        
        # Excitation: FC layers
        y = self.excitation(y).view(b, c, 1, 1)
        
        # Scale
        return x * y.expand_as(x)


class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling: captures multi-scale context
    using dilated convolutions with different dilation rates.
    
    Reference: "DeepLab: Semantic Image Segmentation"
    """
    def __init__(self, in_channels, out_channels, rates=(6, 12, 18)):  # tuple avoids a mutable default argument
        super().__init__()
        
        # 1x1 convolution
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        # Atrous convolutions with different dilation rates
        self.atrous_convs = nn.ModuleList()
        for rate in rates:
            self.atrous_convs.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )
        
        # Global Average Pooling
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        # Final projection
        num_features = out_channels * (2 + len(rates))  # 1x1 + atrous branches + global pooling
        self.project = nn.Sequential(
            nn.Conv2d(num_features, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
    
    def forward(self, x):
        size = x.shape[2:]
        
        features = [self.conv1x1(x)]
        
        for atrous_conv in self.atrous_convs:
            features.append(atrous_conv(x))
        
        # Global pooling feature
        global_feat = self.global_pool(x)
        global_feat = nn.functional.interpolate(global_feat, size=size, mode='bilinear', align_corners=False)
        features.append(global_feat)
        
        # Concatenate and project
        x = torch.cat(features, dim=1)
        x = self.project(x)
        
        return x

print("Attention modules defined!")

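A quick calculation helps when judging the ASPP dilation rates against the size of the bottleneck feature map. A dilated kernel of size k with rate r spans k + (k - 1)(r - 1) pixels. The sketch assumes 512x512 inputs (the value of `EXP3_IMG_SIZE`; the actual size depends on the transforms used) and the 32x downsampling of a depth-5 ResNet50 encoder. Under those assumptions the map is 16x16, so at rate 18 every non-center tap of the 3x3 kernel falls in the zero padding and that branch behaves like a 1x1 convolution.

```python
def effective_kernel_size(kernel_size, dilation):
    """Span, in pixels, covered by a dilated convolution kernel."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)

input_size = 512                      # assumed input resolution
bottleneck_size = input_size // 32    # depth-5 ResNet50 downsamples by 32
spans = {rate: effective_kernel_size(3, rate) for rate in (6, 12, 18)}
for rate, span in spans.items():
    print(f"rate={rate:2d}: kernel spans {span} px on a "
          f"{bottleneck_size}x{bottleneck_size} map")
```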
## Exp3.2 - Attention Modules (Attention Gates)

In [None]:
# ============================================
# EXPERIMENT 3 - Configuration
# ============================================

EXP3_BATCH_SIZE = 8
EXP3_NUM_EPOCHS = 50
EXP3_LEARNING_RATE = 1e-4
EXP3_IMG_SIZE = 512

EXP3_ENCODER = 'resnet50'
EXP3_ENCODER_WEIGHTS = 'imagenet'

print("=== Experiment 3: Attention U-Net with an Architecture Modification ===")

## Exp3.1 - Experiment 3 Configuration

---
# EXPERIMENT 3: Attention U-Net (Architecture Modification)

**Changes to the neural network architecture:**
1. **Attention Gates**: attention mechanisms on the skip connections that let the model focus on relevant regions
2. **Squeeze-and-Excitation (SE) Blocks**: adaptive recalibration of the feature channels
3. **ASPP (Atrous Spatial Pyramid Pooling)**: multi-scale context capture at the bottleneck

This is a substantial change to the network topology, not merely an increase in width or depth.

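The additive attention behind an attention gate can be written out for a single pixel. The sketch below is a simplified illustration: it omits the biases and BatchNorm layers that the notebook's `AttentionGate` uses, treats each 1x1 convolution as a small matrix, and uses made-up weights. The gate produces one coefficient between 0 and 1 per pixel and rescales all skip channels by it.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dot(weights, values):
    return sum(w * v for w, v in zip(weights, values))

def attention_gate_pixel(g, x, W_g, W_x, psi):
    """Gate one pixel: g and x are channel vectors, W_g and W_x are weight matrices."""
    # A 1x1 convolution is a per-pixel linear map into the intermediate space.
    g1 = [dot(row, g) for row in W_g]
    x1 = [dot(row, x) for row in W_x]
    hidden = [max(0.0, a + b) for a, b in zip(g1, x1)]   # ReLU(W_g g + W_x x)
    alpha = sigmoid(dot(psi, hidden))                    # scalar coefficient in (0, 1)
    return alpha, [alpha * v for v in x]                 # rescaled skip features

# Hypothetical weights: 2 gating channels, 3 skip channels, 2 intermediate channels.
W_g = [[0.5, -0.2], [0.1, 0.3]]
W_x = [[0.2, 0.0, -0.1], [0.4, 0.1, 0.0]]
psi = [1.0, -0.5]
alpha, gated = attention_gate_pixel([1.0, 2.0], [0.5, 1.0, 1.5], W_g, W_x, psi)
print(f"alpha={alpha:.4f}")
```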
In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

axes[0].plot(exp2_history['train_loss'], label='Train')
axes[0].plot(exp2_history['val_loss'], label='Validation')
axes[0].set_title('Exp2 - Loss')
axes[0].legend()

axes[1].plot(exp2_history['train_iou'], label='Train')
axes[1].plot(exp2_history['val_iou'], label='Validation')
axes[1].set_title('Exp2 - IoU')
axes[1].legend()

axes[2].plot(exp2_history['train_dice'], label='Train')
axes[2].plot(exp2_history['val_dice'], label='Validation')
axes[2].set_title('Exp2 - Dice Score')
axes[2].legend()

for ax in axes:
    ax.set_xlabel('Epoch')
    ax.grid(True, alpha=0.3)

plt.suptitle('Experiment 2: Preprocessing + Data Augmentation + Deep Supervision', fontsize=14)
plt.tight_layout()
plt.show()

## Exp2.9 - Training Curves for Experiment 2

In [None]:
# Load the best model and evaluate it with TTA
exp2_model.load_state_dict(torch.load('best_exp2_model.pth', map_location=device))
exp2_model.eval()

# Build the TTA wrapper
tta = TestTimeAugmentation(exp2_model, device)

# Per-image metrics without TTA
all_iou_no_tta = []
all_dice_no_tta = []

# Per-image metrics with TTA
all_iou_tta = []
all_dice_tta = []

print("Evaluating Experiment 2 (with and without TTA)...")

with torch.no_grad():
    for images, masks in tqdm(exp2_val_loader, desc='Evaluating'):
        images, masks = images.to(device), masks.to(device)
        
        for i in range(images.shape[0]):
            img = images[i:i+1]
            mask = masks[i:i+1]
            
            # Without TTA
            pred_no_tta = torch.sigmoid(exp2_model(img))
            pred_bin_no_tta = (pred_no_tta > 0.5).float()
            
            intersection = (pred_bin_no_tta * mask).sum()
            union = pred_bin_no_tta.sum() + mask.sum() - intersection
            iou_no_tta = (intersection + 1e-6) / (union + 1e-6)
            dice_no_tta = (2 * intersection + 1e-6) / (pred_bin_no_tta.sum() + mask.sum() + 1e-6)
            
            all_iou_no_tta.append(iou_no_tta.item())
            all_dice_no_tta.append(dice_no_tta.item())
            
            # With TTA
            pred_tta = tta(img)
            pred_bin_tta = (pred_tta > 0.5).float()
            
            intersection = (pred_bin_tta * mask).sum()
            union = pred_bin_tta.sum() + mask.sum() - intersection
            iou_tta = (intersection + 1e-6) / (union + 1e-6)
            dice_tta = (2 * intersection + 1e-6) / (pred_bin_tta.sum() + mask.sum() + 1e-6)
            
            all_iou_tta.append(iou_tta.item())
            all_dice_tta.append(dice_tta.item())

print('\n' + '='*60)
print('=== EXPERIMENT 2 RESULTS ===')
print('='*60)
print('\nWithout TTA:')
print(f'  IoU  - Mean: {np.mean(all_iou_no_tta):.4f} | Std: {np.std(all_iou_no_tta):.4f}')
print(f'  Dice - Mean: {np.mean(all_dice_no_tta):.4f} | Std: {np.std(all_dice_no_tta):.4f}')
print('\nWith TTA (7 augmentations):')
print(f'  IoU  - Mean: {np.mean(all_iou_tta):.4f} | Std: {np.std(all_iou_tta):.4f}')
print(f'  Dice - Mean: {np.mean(all_dice_tta):.4f} | Std: {np.std(all_dice_tta):.4f}')
# The :+ format prints the correct sign even when TTA lowers the score
print('\nChange with TTA (percentage points):')
print(f'  IoU:  {(np.mean(all_iou_tta) - np.mean(all_iou_no_tta))*100:+.2f} p.p.')
print(f'  Dice: {(np.mean(all_dice_tta) - np.mean(all_dice_no_tta))*100:+.2f} p.p.')

## Exp2.8 - Final Evaluation with TTA

In [None]:
# Training loop for Experiment 2
exp2_history = {'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': [], 
                'train_dice': [], 'val_dice': []}
exp2_best_dice = 0

print("="*60)
print("STARTING TRAINING - EXPERIMENT 2")
print("="*60)

for epoch in range(EXP2_NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{EXP2_NUM_EPOCHS}')
    
    train_loss, train_iou, train_dice = train_epoch_exp2(
        exp2_model, exp2_train_loader, exp2_criterion, exp2_optimizer
    )
    val_loss, val_iou, val_dice = validate_exp2(
        exp2_model, exp2_val_loader, exp2_criterion
    )
    exp2_scheduler.step()
    
    exp2_history['train_loss'].append(train_loss)
    exp2_history['val_loss'].append(val_loss)
    exp2_history['train_iou'].append(train_iou)
    exp2_history['val_iou'].append(val_iou)
    exp2_history['train_dice'].append(train_dice)
    exp2_history['val_dice'].append(val_dice)
    
    print(f'Train - Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f}')
    print(f'Val   - Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f}')
    print(f'LR: {exp2_scheduler.get_last_lr()[0]:.2e}')
    
    if val_dice > exp2_best_dice:
        exp2_best_dice = val_dice
        torch.save(exp2_model.state_dict(), 'best_exp2_model.pth')
        print(f'*** Exp2 model saved! Dice: {exp2_best_dice:.4f} ***')

print("\n" + "="*60)
print(f"EXPERIMENT 2 COMPLETE - Best Dice: {exp2_best_dice:.4f}")
print("="*60)

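The training loop above steps `exp2_scheduler` once per epoch. With the `T_0=10`, `T_mult=2`, and `eta_min=1e-6` passed to CosineAnnealingWarmRestarts, the cycles last 10, 20, and 40 epochs. The sketch below reproduces the documented cosine formula in plain Python to show where the restarts fall; `base_lr` is a hypothetical value because `EXP2_LEARNING_RATE` is defined elsewhere in the notebook, and 50 epochs is likewise an assumption.

```python
import math

def warm_restart_lr(epoch, base_lr, T_0=10, T_mult=2, eta_min=1e-6):
    """Learning rate after `epoch` scheduler steps under cosine warm restarts."""
    T_i, T_cur = T_0, epoch
    while T_cur >= T_i:      # skip over completed cycles
        T_cur -= T_i
        T_i *= T_mult        # each cycle is T_mult times longer than the last
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2

base_lr = 1e-4  # hypothetical value
schedule = [warm_restart_lr(e, base_lr) for e in range(50)]
# A restart is any epoch where the learning rate jumps back up.
restarts = [e for e in range(1, 50) if schedule[e] > schedule[e - 1]]
print("restarts at epochs:", restarts)
```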
In [None]:
def train_epoch_exp2(model, loader, criterion, optimizer):
    """Train for one epoch with Deep Supervision."""
    model.train()
    total_loss = 0
    total_iou = 0
    total_dice = 0
    
    for images, masks in tqdm(loader, desc='Train Exp2'):
        images, masks = images.to(device), masks.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        
        # Deep supervision loss (main output plus auxiliary outputs)
        loss, main_output = deep_supervision_loss(outputs, masks, criterion)
        loss.backward()
        
        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        total_loss += loss.item()
        iou, dice = calc_metrics(main_output, masks)
        total_iou += iou
        total_dice += dice
    
    n = len(loader)
    return total_loss/n, total_iou/n, total_dice/n

@torch.no_grad()
def validate_exp2(model, loader, criterion):
    """Validation for Experiment 2"""
    model.eval()
    total_loss = 0
    total_iou = 0
    total_dice = 0
    
    for images, masks in tqdm(loader, desc='Val Exp2'):
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        
        # In eval mode the model returns only the main output
        loss = criterion(outputs, masks)
        
        total_loss += loss.item()
        iou, dice = calc_metrics(outputs, masks)
        total_iou += iou
        total_dice += dice
    
    n = len(loader)
    return total_loss/n, total_iou/n, total_dice/n

print("Exp2 training functions defined!")

In [None]:
# Create the model with Deep Supervision
exp2_model = UNetWithDeepSupervision(
    encoder_name=EXP2_ENCODER,
    encoder_weights=EXP2_ENCODER_WEIGHTS,
    in_channels=3,
    classes=1
).to(device)

# Loss and optimizer
exp2_dice_loss = smp.losses.DiceLoss(mode='binary')
exp2_bce_loss = smp.losses.SoftBCEWithLogitsLoss()

def exp2_criterion(pred, target):
    return 0.5 * exp2_bce_loss(pred, target) + 0.5 * exp2_dice_loss(pred, target)

exp2_optimizer = optim.AdamW(exp2_model.parameters(), lr=EXP2_LEARNING_RATE, weight_decay=1e-4)
exp2_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    exp2_optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

print('Exp2 - Model: U-Net with Deep Supervision')
print(f'Exp2 - Encoder: {EXP2_ENCODER}')
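
The warm-restart schedule configured above (`T_0=10`, `T_mult=2`, `eta_min=1e-6`) can be previewed without PyTorch. The sketch below is a plain-Python re-implementation of the cosine-with-restarts rule, assuming one scheduler step per epoch and a base learning rate of `1e-4`; `warm_restart_lr` is a hypothetical helper, not part of the notebook. Under those assumptions the cycles last 10, 20 and 40 epochs, so a 50-epoch run stops partway through the third cycle.

```python
import math

def warm_restart_lr(epoch, base_lr=1e-4, t_0=10, t_mult=2, eta_min=1e-6):
    """Learning rate after `epoch` scheduler steps (illustrative re-implementation)."""
    t_i, t_cur = t_0, epoch
    # Walk through completed cycles; each new cycle is t_mult times longer
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i)) / 2

# Epochs at which the learning rate jumps back up (restarts) within 50 epochs
restarts = [e for e in range(1, 50) if warm_restart_lr(e) > warm_restart_lr(e - 1)]
print(restarts)

# Learning rate at the last epoch of a 50-epoch run
print(f'{warm_restart_lr(49):.2e}')
```

If ending the run at the bottom of a cycle is preferred, the number of epochs can be aligned with a cycle boundary (10, 30 or 70 under these settings).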

## Exp2.7 - Model and Training for Experiment 2

In [None]:
def compare_clahe_effect(pairs, idx=0):
    """Compares the original image with its CLAHE-enhanced version"""
    pair = pairs[idx]
    
    # Load the original image
    img_original = np.array(Image.open(pair['image']).convert('RGB'))
    
    # Apply CLAHE
    img_clahe = apply_clahe_preprocessing(img_original)
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    axes[0].imshow(img_original)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    axes[1].imshow(img_clahe)
    axes[1].set_title('With CLAHE (Enhanced Contrast)')
    axes[1].axis('off')
    
    plt.suptitle('Effect of CLAHE Preprocessing', fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize the CLAHE effect
compare_clahe_effect(pairs, 0)

## Exp2.6 - Visualizing the Effect of CLAHE Preprocessing

In [None]:
# Create datasets with the new transforms
exp2_train_dataset = OpticDiscDataset(train_pairs, get_exp2_train_transforms())
exp2_val_dataset = OpticDiscDataset(val_pairs, get_exp2_val_transforms())

exp2_train_loader = DataLoader(exp2_train_dataset, batch_size=EXP2_BATCH_SIZE, 
                                shuffle=True, num_workers=2, pin_memory=True)
exp2_val_loader = DataLoader(exp2_val_dataset, batch_size=EXP2_BATCH_SIZE, 
                              shuffle=False, num_workers=2, pin_memory=True)

print(f'Exp2 - Train: {len(exp2_train_dataset)} | Validation: {len(exp2_val_dataset)}')

## Exp2.5 - Dataset and DataLoaders for Experiment 2

In [None]:
class TestTimeAugmentation:
    """
    Test Time Augmentation: runs the model on several transformed copies of
    the input, undoes each transform on the prediction, and averages the
    results for a more robust prediction
    """
    def __init__(self, model, device):
        self.model = model
        self.device = device
        
    def __call__(self, image):
        """
        image: normalized tensor (1, C, H, W); H == W is required so that the
               90° and 270° rotations keep the shape
        returns: mean of the predicted probabilities (1, 1, H, W)
        """
        self.model.eval()
        
        # Each entry: (transform applied to the input, inverse applied to the prediction)
        views = [
            (lambda t: t, lambda t: t),  # original
            (lambda t: torch.flip(t, dims=[3]), lambda t: torch.flip(t, dims=[3])),  # horizontal flip
            (lambda t: torch.flip(t, dims=[2]), lambda t: torch.flip(t, dims=[2])),  # vertical flip
            # Both flips: the same view as the 180° rotation below, so it counts twice in the mean
            (lambda t: torch.flip(t, dims=[2, 3]), lambda t: torch.flip(t, dims=[2, 3])),
            (lambda t: torch.rot90(t, k=1, dims=[2, 3]), lambda t: torch.rot90(t, k=-1, dims=[2, 3])),  # 90° rotation
            (lambda t: torch.rot90(t, k=2, dims=[2, 3]), lambda t: torch.rot90(t, k=-2, dims=[2, 3])),  # 180° rotation
            (lambda t: torch.rot90(t, k=3, dims=[2, 3]), lambda t: torch.rot90(t, k=-3, dims=[2, 3])),  # 270° rotation
        ]
        
        predictions = []
        with torch.no_grad():
            for forward, inverse in views:
                pred = torch.sigmoid(self.model(forward(image)))
                predictions.append(inverse(pred))
        
        # Mean over all predictions
        avg_prediction = torch.stack(predictions).mean(dim=0)
        return avg_prediction

print("TTA defined!")
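
One detail of the seven TTA views is easy to check by hand: flipping both axes produces the same image as a 180° rotation, so the seven views contain six distinct ones and that orientation is weighted twice in the average. The sketch below verifies this on a small nested list, using hypothetical pure-Python helpers (`flip_h`, `flip_v`, `rot90_ccw`) in place of `torch.flip` and `torch.rot90`.

```python
def flip_h(m):
    """Mirror left-right."""
    return [row[::-1] for row in m]

def flip_v(m):
    """Mirror top-bottom."""
    return [row[:] for row in m[::-1]]

def rot90_ccw(m):
    """Rotate 90 degrees counterclockwise."""
    return [list(col) for col in zip(*m)][::-1]

m = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

views = {
    'original': m,
    'flip_h': flip_h(m),
    'flip_v': flip_v(m),
    'flip_hv': flip_h(flip_v(m)),
    'rot90': rot90_ccw(m),
    'rot180': rot90_ccw(rot90_ccw(m)),
    'rot270': rot90_ccw(rot90_ccw(rot90_ccw(m))),
}

# Both flips together are the same view as the 180° rotation
same_view = views['flip_hv'] == views['rot180']

# Number of distinct views among the seven
distinct = {tuple(map(tuple, v)) for v in views.values()}
print(same_view, len(distinct))
```

Replacing one of the two equivalent views with a transposed view would be one way to reach seven (or eight) distinct orientations; whether that changes the results is not something this sketch measures.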

## Exp2.4 - Test Time Augmentation (TTA)

In [None]:
class UNetWithDeepSupervision(nn.Module):
    """
    U-Net with Deep Supervision: adds auxiliary outputs at different
    decoder scales to improve gradient flow and convergence
    """
    def __init__(self, encoder_name='resnet50', encoder_weights='imagenet', 
                 in_channels=3, classes=1):
        super().__init__()
        
        # Base U-Net model
        self.base_model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=None,
            decoder_channels=(256, 128, 64, 32, 16)
        )
        
        # Deep supervision heads for intermediate decoder stages.
        # With decoder_channels=(256, 128, 64, 32, 16) the decoder blocks output,
        # in order, 1/16, 1/8, 1/4, 1/2 and full resolution.
        self.ds_head_1 = nn.Conv2d(32, classes, kernel_size=1)   # 1/2 resolution
        self.ds_head_2 = nn.Conv2d(64, classes, kernel_size=1)   # 1/4 resolution
        self.ds_head_3 = nn.Conv2d(128, classes, kernel_size=1)  # 1/8 resolution (not used in forward)
        
        self.training_mode = True
    
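    def forward(self, x):
        # Encoder: smp encoders return features ordered from shallow to deep,
        # and features[0] is the input-resolution tensor
        features = self.base_model.encoder(x)
        
        # Mirror smp's U-Net decoder: drop the input-resolution feature, reverse so
        # the deepest feature comes first, and use the remaining ones as skips
        features = features[1:][::-1]
        decoder_output = self.base_model.decoder.center(features[0])
        skips = features[1:]
        
        decoder_blocks = self.base_model.decoder.blocks
        
        # Store intermediate outputs
        intermediate_outputs = []
        
        # Run each decoder block.
        # NOTE: assumes an smp release whose decoder blocks are called as
        # block(x, skip); other releases may use a different signature.
        for i, block in enumerate(decoder_blocks):
            skip = skips[i] if i < len(skips) else None
            decoder_output = block(decoder_output, skip)
            
            # Capture outputs for deep supervision (decoder blocks 2 and 3)
            if i == 2:  # 64 channels, 1/4 resolution
                intermediate_outputs.append(decoder_output)
            elif i == 3:  # 32 channels, 1/2 resolution
                intermediate_outputs.append(decoder_output)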
        
        # Main output
        main_output = self.base_model.segmentation_head(decoder_output)
        
        if self.training_mode and self.training:
            # Auxiliary outputs (deep supervision)
            ds_out_2 = self.ds_head_2(intermediate_outputs[0])  # decoder block 2: 64 channels, 1/4
            ds_out_1 = self.ds_head_1(intermediate_outputs[1])  # decoder block 3: 32 channels, 1/2
            
            # Upsample to the size of the main output
            target_size = main_output.shape[2:]
            ds_out_1 = nn.functional.interpolate(ds_out_1, size=target_size, mode='bilinear', align_corners=False)
            ds_out_2 = nn.functional.interpolate(ds_out_2, size=target_size, mode='bilinear', align_corners=False)
            
            return main_output, ds_out_1, ds_out_2
        
        return main_output

def deep_supervision_loss(outputs, target, criterion, weights=(1.0, 0.4, 0.2)):
    """
    Computes the combined loss for deep supervision
    weights: weights for [main output, ds1, ds2]
    """
    if isinstance(outputs, tuple):
        main_out, ds1, ds2 = outputs
        loss = weights[0] * criterion(main_out, target)
        loss += weights[1] * criterion(ds1, target)
        loss += weights[2] * criterion(ds2, target)
        return loss, main_out
    else:
        return criterion(outputs, target), outputs

print("Deep Supervision defined!")
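
To see which decoder stage each deep supervision head is attached to, it helps to tabulate channels and resolution per decoder block. The sketch below is plain arithmetic, assuming the usual smp U-Net layout for a ResNet encoder: the deepest encoder feature has stride 32 and each of the five decoder blocks upsamples by 2. `decoder_stage_shapes` is a hypothetical helper written for this illustration.

```python
def decoder_stage_shapes(img_size=512, decoder_channels=(256, 128, 64, 32, 16),
                         encoder_stride=32):
    """Return (block index, channels, spatial size, downscale factor) per decoder block."""
    size = img_size // encoder_stride  # spatial size of the deepest encoder feature
    stages = []
    for i, channels in enumerate(decoder_channels):
        size *= 2  # each decoder block upsamples by a factor of 2
        stages.append((i, channels, size, img_size // size))
    return stages

stages = decoder_stage_shapes()
for i, channels, size, factor in stages:
    print(f'block {i}: {channels:>3} channels, {size}x{size} (1/{factor} of the input)')
```

Under these assumptions, blocks 2 and 3 (the ones captured in `forward`) carry 64 and 32 channels at 1/4 and 1/2 of the input resolution, which matches the input channels of `ds_head_2` and `ds_head_1`.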

## Exp2.3 - Deep Supervision (Multiple Outputs)

In [None]:
import cv2

def apply_clahe_preprocessing(image, **kwargs):
    """
    Applies CLAHE (Contrast Limited Adaptive Histogram Equalization)
    to the luminance channel to enhance structures in fundus images.
    **kwargs absorbs the extra parameters that A.Lambda passes to its callback.
    """
    # Convert to the LAB color space
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    
    # Apply CLAHE to the L (luminance) channel
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l_clahe = clahe.apply(l)
    
    # Merge the channels back
    lab_clahe = cv2.merge([l_clahe, a, b])
    
    # Convert back to RGB
    result = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
    return result

def get_exp2_train_transforms():
    """
    Advanced data augmentation for medical images:
    - CLAHE as preprocessing
    - ElasticTransform to simulate anatomical deformations
    - GridDistortion and OpticalDistortion
    - CoarseDropout for regularization
    """
    return A.Compose([
        # CLAHE preprocessing (always applied)
        A.Lambda(image=apply_clahe_preprocessing),
        
        A.Resize(EXP2_IMG_SIZE, EXP2_IMG_SIZE),
        
        # Geometric augmentations
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=45, p=0.5),
        
        # Deformations suited to medical images
        A.OneOf([
            A.ElasticTransform(alpha=120, sigma=120 * 0.05, p=1.0),
            A.GridDistortion(num_steps=5, distort_limit=0.3, p=1.0),
            A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=1.0),
        ], p=0.4),
        
        # Color/intensity augmentations
        A.OneOf([
            A.GaussNoise(var_limit=(10, 50)),
            A.GaussianBlur(blur_limit=3),
            A.MedianBlur(blur_limit=3),
            A.MotionBlur(blur_limit=3),
        ], p=0.3),
        
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
            A.CLAHE(clip_limit=4),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20),
            A.RandomGamma(gamma_limit=(80, 120)),
        ], p=0.4),
        
        # Regularization: CoarseDropout (simulates occlusions)
        A.CoarseDropout(max_holes=8, max_height=32, max_width=32, 
                        min_holes=1, min_height=8, min_width=8,
                        fill_value=0, p=0.3),
        
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

def get_exp2_val_transforms():
    """Validation transforms with CLAHE"""
    return A.Compose([
        A.Lambda(image=apply_clahe_preprocessing),
        A.Resize(EXP2_IMG_SIZE, EXP2_IMG_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

print("Experiment 2 transforms defined!")
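
The `A.OneOf` blocks in the training pipeline make each individual transform rarer than the block probability suggests. As far as I understand the Albumentations behavior, `OneOf` fires with its own `p`, then picks exactly one child with probability proportional to the children's `p` values. The sketch below only reproduces that arithmetic; `one_of_probabilities` is a hypothetical helper and does not call Albumentations.

```python
def one_of_probabilities(block_p, child_ps):
    """Effective per-sample probability of each child inside a OneOf block."""
    total = sum(child_ps)
    return [block_p * p / total for p in child_ps]

# Deformation block: three children with p=1.0 inside OneOf(p=0.4)
deform = one_of_probabilities(0.4, [1.0, 1.0, 1.0])

# Noise/blur block: four children sharing the same p inside OneOf(p=0.3)
noise_blur = one_of_probabilities(0.3, [0.5, 0.5, 0.5, 0.5])

print([round(p, 4) for p in deform])
print([round(p, 4) for p in noise_blur])
```

So, under this reading, ElasticTransform is applied to roughly 13% of training samples and each noise or blur variant to 7.5%, which is worth keeping in mind when judging how aggressive the augmentation really is.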

## Exp2.2 - CLAHE Preprocessing and Advanced Data Augmentation

# Optic Disc Segmentation

Using a U-Net with a pretrained ResNet encoder for medical images (segmentation_models_pytorch)

## 1. Installation and Imports

In [20]:
# Install the required libraries
!pip install segmentation-models-pytorch albumentations -q

In [21]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp
from skimage.draw import polygon

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


## 2. Configuration

In [None]:
# ==== OPTION 1: Mount Google Drive ====
# (If this fails, use Option 2 below)
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    ROOT_DIR = '/content/drive/MyDrive/PapilaDB/'
    print("Drive mounted!")
except Exception:  # also covers ImportError when running outside Colab
    print("Failed to mount Drive. Use Option 2.")
    ROOT_DIR = '/content/PapilaDB/'

# ==== OPTION 2: Direct upload via ZIP ====
# 1. Compress the PapilaDB folder into a .zip file
# 2. Uncomment and run the lines below:

# from google.colab import files
# import zipfile
# uploaded = files.upload()  # Select the PapilaDB.zip file
# with zipfile.ZipFile('PapilaDB.zip', 'r') as zip_ref:
#     zip_ref.extractall('/content/')
# ROOT_DIR = '/content/PapilaDB/'

# Check that the dataset folder exists
import os
print(f"ROOT_DIR: {ROOT_DIR}")
print(f"Exists: {os.path.exists(ROOT_DIR)}")
if os.path.exists(ROOT_DIR):
    print(f"Contents: {os.listdir(ROOT_DIR)}")

BATCH_SIZE = 8
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
IMG_SIZE = 512

ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'

## 3. Prepare Data

In [None]:
img_dir = ROOT_DIR + 'FundusImages/'
contour_dir = ROOT_DIR + 'ExpertsSegmentations/Contours/'

img_files = sorted(os.listdir(img_dir))
contour_files = sorted(os.listdir(contour_dir))

# Keep only the optic disc contours
disc_contours = [f for f in contour_files if 'disc' in f.lower()]

print(f'Images: {len(img_files)}')
print(f'Disc contours: {len(disc_contours)}')

In [None]:
# Build image-contour pairs
def get_pairs():
    pairs = []
    for img_file in img_files:
        img_id = os.path.splitext(img_file)[0]
        for cont in disc_contours:
            if img_id in cont:
                pairs.append({
                    'image': os.path.join(img_dir, img_file),
                    'contour': os.path.join(contour_dir, cont)
                })
                break
    return pairs

pairs = get_pairs()
print(f'Pairs found: {len(pairs)}')

## 4. Data Augmentation (Albumentations)

In [None]:
def get_train_transforms():
    return A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.5),
        A.OneOf([
            A.GaussNoise(var_limit=(10, 50)),
            A.GaussianBlur(blur_limit=3),
            A.MedianBlur(blur_limit=3),
        ], p=0.3),
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.CLAHE(clip_limit=2),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10),
        ], p=0.3),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

def get_val_transforms():
    return A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

## 5. Dataset

In [None]:
class OpticDiscDataset(Dataset):
    def __init__(self, pairs, transforms=None):
        self.pairs = pairs
        self.transforms = transforms
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        
        # Load the image
        image = np.array(Image.open(pair['image']).convert('RGB'))
        h, w = image.shape[:2]
        
        # Build the mask from the contour (file columns are x, y)
        contour = np.loadtxt(pair['contour'])
        mask = np.zeros((h, w), dtype=np.uint8)
        
        rr, cc = polygon(contour[:, 1], contour[:, 0], mask.shape)
        mask[rr, cc] = 1
        
        # Apply the transforms
        if self.transforms:
            transformed = self.transforms(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        
        return image, mask.float().unsqueeze(0)

In [None]:
# Split the data into training and validation sets
train_pairs, val_pairs = train_test_split(pairs, test_size=0.2, random_state=42)

train_dataset = OpticDiscDataset(train_pairs, get_train_transforms())
val_dataset = OpticDiscDataset(val_pairs, get_val_transforms())

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f'Train: {len(train_dataset)} | Validation: {len(val_dataset)}')

## 6. Visualize Samples

In [None]:
def show_sample(dataset, idx=0):
    img, mask = dataset[idx]
    
    # Undo the normalization
    img_np = img.numpy().transpose(1, 2, 0)
    img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img_np = np.clip(img_np, 0, 1)
    
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(img_np)
    ax[0].set_title('Image')
    ax[1].imshow(mask.squeeze(), cmap='gray')
    ax[1].set_title('Mask')
    
    # Overlay
    overlay = img_np.copy()
    m = mask.squeeze().numpy()
    overlay[m > 0.5] = overlay[m > 0.5] * 0.5 + np.array([0, 1, 0]) * 0.5
    ax[2].imshow(overlay)
    ax[2].set_title('Overlay')
    
    for a in ax: a.axis('off')
    plt.tight_layout()
    plt.show()

show_sample(train_dataset, 0)

## 7. U-Net Model with ResNet (SMP)

In [None]:
# Create the model with segmentation_models_pytorch
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=1,
    activation=None  # The model outputs logits; the sigmoid is applied inside the loss
)

model = model.to(device)
print(f'Model: U-Net with {ENCODER} encoder')
print(f'Weights: {ENCODER_WEIGHTS}')

In [None]:
# Alternative: use other SMP models
# model = smp.DeepLabV3Plus(encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=1)
# model = smp.FPN(encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=1)
# model = smp.PSPNet(encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=1)

## 8. Loss and Metrics

In [None]:
# Combined loss built from SMP losses
dice_loss = smp.losses.DiceLoss(mode='binary')
bce_loss = smp.losses.SoftBCEWithLogitsLoss()

def criterion(pred, target):
    return 0.5 * bce_loss(pred, target) + 0.5 * dice_loss(pred, target)

# Metrics (computed over all pixels of the batch passed in)
def calc_metrics(pred, target, threshold=0.5):
    pred = torch.sigmoid(pred)
    pred_bin = (pred > threshold).float()
    
    # IoU
    intersection = (pred_bin * target).sum()
    union = pred_bin.sum() + target.sum() - intersection
    iou = (intersection + 1e-6) / (union + 1e-6)
    
    # Dice
    dice = (2 * intersection + 1e-6) / (pred_bin.sum() + target.sum() + 1e-6)
    
    return iou.item(), dice.item()
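
IoU and Dice, as computed above, are two views of the same overlap: for any pair of binary masks, Dice = 2·IoU / (1 + IoU), so Dice is always at least as large as IoU. The worked example below checks this on a toy pair of flattened masks; `iou_and_dice` is a hypothetical pure-Python helper that omits the `1e-6` smoothing term used in `calc_metrics`.

```python
def iou_and_dice(pred, target):
    """IoU and Dice for two flat binary sequences of 0/1 values."""
    intersection = sum(p * t for p, t in zip(pred, target))
    pred_sum, target_sum = sum(pred), sum(target)
    union = pred_sum + target_sum - intersection
    iou = intersection / union
    dice = 2 * intersection / (pred_sum + target_sum)
    return iou, dice

pred   = [1, 1, 1, 0, 0, 0]
target = [0, 1, 1, 1, 0, 0]

iou, dice = iou_and_dice(pred, target)
print(f'IoU: {iou:.4f} | Dice: {dice:.4f}')

# The two metrics are tied by Dice = 2 * IoU / (1 + IoU)
print(abs(dice - 2 * iou / (1 + iou)) < 1e-12)
```

Because the relation is monotonic, ranking models by Dice or by IoU on a single image gives the same order; the two can disagree only after averaging over several images.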

In [None]:
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
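
For reference, the learning rate produced by `CosineAnnealingLR` follows a closed form when the scheduler is stepped once per epoch: it starts at the base rate and decays along half a cosine to `eta_min` (0 by default) at `T_max`. The sketch below evaluates that formula in plain Python for the baseline settings (`LEARNING_RATE = 1e-4`, `NUM_EPOCHS = 50`); `cosine_annealing_lr` is a hypothetical helper, not the PyTorch class.

```python
import math

def cosine_annealing_lr(epoch, base_lr=1e-4, t_max=50, eta_min=0.0):
    """Closed-form cosine annealing value after `epoch` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

for epoch in (0, 10, 25, 40, 50):
    print(f'epoch {epoch:>2}: lr = {cosine_annealing_lr(epoch):.2e}')
```

The rate is halved at the midpoint and reaches zero only at the last step, so the final epochs of the baseline run train with a very small learning rate.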

## 9. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer):
    model.train()
    total_loss = 0
    total_iou = 0
    total_dice = 0
    
    for images, masks in tqdm(loader, desc='Train'):
        images, masks = images.to(device), masks.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        iou, dice = calc_metrics(outputs, masks)
        total_iou += iou
        total_dice += dice
    
    n = len(loader)
    return total_loss/n, total_iou/n, total_dice/n


@torch.no_grad()
def validate(model, loader, criterion):
    model.eval()
    total_loss = 0
    total_iou = 0
    total_dice = 0
    
    for images, masks in tqdm(loader, desc='Val'):
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        total_loss += loss.item()
        iou, dice = calc_metrics(outputs, masks)
        total_iou += iou
        total_dice += dice
    
    n = len(loader)
    return total_loss/n, total_iou/n, total_dice/n

## 10. Training

In [None]:
history = {'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': [], 
           'train_dice': [], 'val_dice': []}
best_dice = 0

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')
    
    train_loss, train_iou, train_dice = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_iou, val_dice = validate(model, val_loader, criterion)
    scheduler.step()
    
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_iou'].append(train_iou)
    history['val_iou'].append(val_iou)
    history['train_dice'].append(train_dice)
    history['val_dice'].append(val_dice)
    
    print(f'Train - Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f}')
    print(f'Val   - Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f}')
    
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), 'best_optic_disc_model.pth')
        print(f'*** Model saved! Dice: {best_dice:.4f} ***')

## 11. Training Curves

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

axes[0].plot(history['train_loss'], label='Train')
axes[0].plot(history['val_loss'], label='Validation')
axes[0].set_title('Loss')
axes[0].legend()

axes[1].plot(history['train_iou'], label='Train')
axes[1].plot(history['val_iou'], label='Validation')
axes[1].set_title('IoU')
axes[1].legend()

axes[2].plot(history['train_dice'], label='Train')
axes[2].plot(history['val_dice'], label='Validation')
axes[2].set_title('Dice Score')
axes[2].legend()

for ax in axes:
    ax.set_xlabel('Epoch')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 12. Visualize Predictions

In [None]:
# Load the best model (map_location keeps this working on CPU-only machines)
model.load_state_dict(torch.load('best_optic_disc_model.pth', map_location=device))
model.eval()

def predict_and_show(dataset, indices):
    fig, axes = plt.subplots(len(indices), 4, figsize=(20, 5*len(indices)))
    
    for i, idx in enumerate(indices):
        img, mask = dataset[idx]
        
        with torch.no_grad():
            pred = model(img.unsqueeze(0).to(device))
            pred = torch.sigmoid(pred).cpu().squeeze().numpy()
        
        # Undo the image normalization
        img_np = img.numpy().transpose(1, 2, 0)
        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_np = np.clip(img_np, 0, 1)
        
        mask_np = mask.squeeze().numpy()
        pred_bin = (pred > 0.5).astype(np.float32)
        
        # Overlay
        overlay = img_np.copy()
        overlay[pred_bin > 0.5] = overlay[pred_bin > 0.5] * 0.5 + np.array([0, 1, 0]) * 0.5
        
        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title('Image')
        axes[i, 1].imshow(mask_np, cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 2].imshow(pred_bin, cmap='gray')
        axes[i, 2].set_title('Prediction')
        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title('Overlay')
        
        for ax in axes[i]: ax.axis('off')
    
    plt.tight_layout()
    plt.show()

predict_and_show(val_dataset, [0, 1, 2, 3])

## 13. Inference on a New Image

In [None]:
def segment_image(image_path, model, img_size=512):
    """Segments the optic disc in a new image and returns a probability map.

    Note: img_size is currently unused; the resize is governed by IMG_SIZE
    inside get_val_transforms().
    """
    
    # Load and preprocess
    image = np.array(Image.open(image_path).convert('RGB'))
    original_size = image.shape[:2]
    
    transform = get_val_transforms()
    transformed = transform(image=image)
    img_tensor = transformed['image'].unsqueeze(0).to(device)
    
    # Prediction
    model.eval()
    with torch.no_grad():
        pred = model(img_tensor)
        pred = torch.sigmoid(pred).cpu().squeeze().numpy()
    
    # Resize the mask back to the original image size
    pred_resized = np.array(Image.fromarray((pred * 255).astype(np.uint8)).resize(
        (original_size[1], original_size[0]), Image.BILINEAR)) / 255.0
    
    return pred_resized

# Usage example
# mask = segment_image('path/to/image.jpg', model)

## 14. Final Evaluation

In [None]:
# Compute per-image metrics on the validation set
model.load_state_dict(torch.load('best_optic_disc_model.pth', map_location=device))
model.eval()

all_iou = []
all_dice = []

with torch.no_grad():
    for images, masks in val_loader:
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        
        for i in range(outputs.shape[0]):
            iou, dice = calc_metrics(outputs[i:i+1], masks[i:i+1])
            all_iou.append(iou)
            all_dice.append(dice)

print('=== Results on the Validation Set ===')
print(f'IoU  - Mean: {np.mean(all_iou):.4f} | Std: {np.std(all_iou):.4f}')
print(f'Dice - Mean: {np.mean(all_dice):.4f} | Std: {np.std(all_dice):.4f}')
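
The numbers printed here are not directly comparable with the validation Dice logged during training. The training and validation loops call `calc_metrics` on a whole batch, which pools the pixels of all images before computing the ratio, whereas this cell scores each image separately and then averages. The toy example below shows how far the two can drift apart; `dice` is a hypothetical pure-Python helper and the masks are made up for illustration.

```python
def dice(pred, target):
    """Dice coefficient for two flat binary sequences of 0/1 values."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return 2 * intersection / (sum(pred) + sum(target))

# (prediction, ground truth) for two flattened toy images
large_disc = ([1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1])  # perfect match
small_disc = ([1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0])  # complete miss

# Per-image average, as in the final evaluation cell
per_image = (dice(*large_disc) + dice(*small_disc)) / 2

# Pooled over the batch, as calc_metrics does when given a full batch
pooled = dice(large_disc[0] + small_disc[0], large_disc[1] + small_disc[1])

print(f'per-image mean: {per_image:.4f} | pooled: {pooled:.4f}')
```

Pooling lets large, well-segmented discs dominate, so a completely missed small disc barely moves the batch score while it halves the per-image mean in this example.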

---
# EXPERIMENT 2: Preprocessing, Data Augmentation, and Deep Supervision Improvements

**Improvements implemented:**
1. **Advanced preprocessing**: CLAHE (Contrast Limited Adaptive Histogram Equalization) to enhance structures in fundus images
2. **Data augmentation tailored to medical images**: ElasticTransform, GridDistortion, OpticalDistortion
3. **Deep Supervision**: Multiple outputs at different scales to improve gradient flow
4. **Test Time Augmentation (TTA)**: Ensemble of predictions over different augmentations

## Exp2.1 - Experiment 2 Settings

In [None]:
# ============================================
# EXPERIMENT 2 - Settings
# ============================================

# Same base settings as the baseline
EXP2_BATCH_SIZE = 8
EXP2_NUM_EPOCHS = 50
EXP2_LEARNING_RATE = 1e-4
EXP2_IMG_SIZE = 512

EXP2_ENCODER = 'resnet50'
EXP2_ENCODER_WEIGHTS = 'imagenet'

print("=== Experiment 2: Preprocessing and Data Augmentation Improvements ===")