In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import itertools
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.ioff()
from PIL import Image
import os
from pathlib import Path
import time
from torch.nn.utils import spectral_norm
from datetime import datetime
import json

# ================================
# RTX5070
# ================================

In [None]:
# Mixed precision support.
# Preference order: the unified torch.amp API, then the legacy torch.cuda.amp
# API, and finally inert stand-ins so the rest of the notebook can call the
# same names when neither import works.
try:
    from torch.amp import GradScaler, autocast  # newer unified API first
    AMP_AVAILABLE = True
    print("‚úÖ Mixed Precision dispon√≠vel (torch.amp)")
except ImportError:
    try:
        from torch.cuda.amp import GradScaler, autocast  # legacy fallback
        AMP_AVAILABLE = True
        print("‚úÖ Mixed Precision dispon√≠vel (cuda.amp)")
    except ImportError:
        class GradScaler:
            """No-op replacement for the real gradient scaler.

            Accepts (and ignores) any constructor arguments so that call
            sites written against the real API keep working.
            """

            def __init__(self, device='cuda', *args, **kwargs):
                pass

            def scale(self, loss):
                # No scaling: hand the loss back untouched.
                return loss

            def unscale_(self, optimizer):
                # Nothing was scaled, so there is nothing to undo.
                pass

            def step(self, optimizer, *args, **kwargs):
                return optimizer.step(*args, **kwargs)

            def update(self, new_scale=None):
                pass

        class autocast:
            """No-op context manager standing in for the real autocast.

            Accepts (and ignores) any constructor arguments, e.g.
            ``device_type=...`` or ``enabled=...``.
            """

            def __init__(self, device='cuda', *args, **kwargs):
                pass

            def __enter__(self):
                return self

            def __exit__(self, *args):
                # Never suppress exceptions raised inside the block.
                return False

        AMP_AVAILABLE = False
        print("‚ö†Ô∏è Mixed Precision n√£o dispon√≠vel - usando fallback")

In [None]:
try:
    from torchvision.utils import save_image, make_grid
    print("‚úÖ TorchVision utils dispon√≠vel")
except ImportError:
    # NOTE(review): torchvision.transforms is imported unconditionally in the
    # first cell, so this branch is only reached if torchvision.utils alone
    # fails to import -- confirm whether the fallback is still needed.

    def make_grid(tensor, nrow=8, padding=2, normalize=False, value_range=None,
                  pad_value=0.0, **kwargs):
        """Tile a batch of images into a single (C, H, W) image.

        Args:
            tensor: a (N, C, H, W) batch or a single (C, H, W) image.
            nrow: number of images per grid row.
            padding: border width, in pixels, around every image.
            normalize: when true, rescale values into [0, 1].
            value_range: (low, high) used for normalization; defaults to the
                tensor's own min/max.
            pad_value: fill value for the border pixels.
            **kwargs: accepted and ignored for call compatibility.

        Returns:
            The tiled image; a single input image is returned without border.
        """
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)
        if tensor.numel() == 0:
            return tensor

        if normalize:
            if value_range is not None:
                low, high = float(value_range[0]), float(value_range[1])
            else:
                low, high = float(tensor.min()), float(tensor.max())
            # clamp() returns a new tensor, so the caller's data is untouched.
            tensor = (tensor.clamp(low, high) - low) / max(high - low, 1e-5)

        count = tensor.size(0)
        if count == 1:
            return tensor[0]

        cols = min(nrow, count)
        rows = -(-count // cols)  # ceiling division
        img_h, img_w = tensor.size(2), tensor.size(3)
        cell_h, cell_w = img_h + padding, img_w + padding
        grid = tensor.new_full(
            (tensor.size(1), cell_h * rows + padding, cell_w * cols + padding),
            pad_value,
        )
        for k in range(count):
            top = (k // cols) * cell_h + padding
            left = (k % cols) * cell_w + padding
            grid[:, top:top + img_h, left:left + img_w] = tensor[k]
        return grid

    def save_image(tensor, path, nrow=8, normalize=False, value_range=None, **kwargs):
        """Write an image tensor (or a batch, tiled as a grid) to ``path``.

        Values are clipped to [0, 1] before being handed to matplotlib.
        """
        grid = make_grid(
            tensor,
            nrow=nrow,
            normalize=normalize,
            value_range=value_range,
            padding=kwargs.get('padding', 2),
            pad_value=kwargs.get('pad_value', 0.0),
        )
        img = grid.cpu().detach().permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)
        plt.imsave(path, img)

    print("‚ö†Ô∏è TorchVision n√£o dispon√≠vel - usando matplotlib")

print(f"üöÄ PyTorch: {torch.__version__}")
if torch.cuda.is_available():
    print(f"üéØ GPU: {torch.cuda.get_device_name(0)}")
    print(f"üíæ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

# ================================
# CONFIGURAÇÕES
# ================================

In [None]:
class ConfigTest:
    """Quick smoke-test configuration: 2 epochs with moderate settings."""

    # Paths
    DATASET_PATH = "dataset/real2cartoon"
    MODEL_SAVE_PATH = "models_complex/cyclegan_test"

    # Basic parameters
    IMG_SIZE = 256
    BATCH_SIZE = 1
    LR = 0.0001
    # The optimizer and checkpoint cells read LR_G / LR_D; expose them here
    # (both equal to LR) so this config can be selected without AttributeError.
    LR_G = LR
    LR_D = LR
    BETA1 = 0.5
    BETA2 = 0.999

    # Loss weights
    LAMBDA_CYCLE = 20.0
    LAMBDA_IDENTITY = 1

    # Training schedule
    NUM_EPOCHS = 2
    DECAY_EPOCH = 1

    # Monitoring
    LOG_FREQ = 50           # log every 50 batches
    SAMPLE_FREQ = 200       # occasional sample images
    SAVE_FREQ = 1

    # Performance
    NUM_WORKERS = 2
    USE_AMP = AMP_AVAILABLE  # enabled only when an AMP implementation imported
    NUM_RESIDUAL_BLOCKS = 9
    USE_TORCH_COMPILE = False

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class ConfigProduction:
    """Settings for the full training run (50 epochs, separate G/D learning rates)."""

    # --- where data is read from and where artifacts are written ---
    DATASET_PATH = "dataset/real2cartoon"
    MODEL_SAVE_PATH = "models_complex/cyclegan_perfection"

    # --- input geometry ---
    IMG_SIZE = 256
    BATCH_SIZE = 1

    # --- Adam hyper-parameters ---
    # The discriminators use a learning rate four times the generator's.
    LR_G = 1e-4
    LR_D = 4e-4
    BETA1 = 0.0   # author's note: considered critical for the TTUR setup
    BETA2 = 0.9

    # --- loss weights (identity is half of cycle) ---
    LAMBDA_CYCLE = 50.0
    LAMBDA_IDENTITY = 25.0

    # --- schedule ---
    NUM_EPOCHS = 50
    DECAY_EPOCH = 25

    # --- logging / sampling / checkpoint cadence ---
    LOG_FREQ = 10
    SAMPLE_FREQ = 200
    SAVE_FREQ = 5

    # --- runtime ---
    NUM_WORKERS = 0       # load data in the main process (eases debugging)
    USE_AMP = False       # mixed precision switched off for this config
    NUM_RESIDUAL_BLOCKS = 6

    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Active configuration read by every later cell. Swap the two lines below to
# switch between the quick test settings and the full training settings.
#config = ConfigTest()  # quick test run
config = ConfigProduction()  # full training run

In [None]:
def apply_cuda_optimizations():
    """Toggle the cuDNN / TF32 backend flags used for faster training.

    Turns off deterministic cuDNN, turns on the cuDNN autotuner and allows
    TF32 for both matmul and cuDNN. Any failure is reported, not raised.
    """
    try:
        cudnn_backend = torch.backends.cudnn
        matmul_backend = torch.backends.cuda.matmul

        cudnn_backend.deterministic = False
        cudnn_backend.benchmark = True
        matmul_backend.allow_tf32 = True
        cudnn_backend.allow_tf32 = True
        print("‚úÖ Otimiza√ß√µes CUDA aplicadas com seguran√ßa")
    except Exception as e:
        print(f"‚ö†Ô∏è Erro ao aplicar otimiza√ß√µes: {e}")


apply_cuda_optimizations()

In [None]:
def check_gpu_memory():
    """Print the GPU name, its total VRAM and a rough VRAM estimate for the
    configured batch, followed by a coarse safety verdict.

    The estimate scales with ``config.BATCH_SIZE`` and ``config.IMG_SIZE``
    (falling back to 256 when the config has no IMG_SIZE). It is a heuristic
    over input-tensor size only, not a measurement.
    """
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"üîç GPU: {torch.cuda.get_device_name(0)}")
        print(f"üíæ VRAM Total: {total_memory:.1f}GB")
        print(f"üìä Batch configurado: {config.BATCH_SIZE}")

        # Rough estimate: batch * H * W * 3 channels * 4 bytes * a factor of 8.
        # Uses the configured resolution instead of a hard-coded 256.
        img_size = getattr(config, 'IMG_SIZE', 256)
        estimated_vram = config.BATCH_SIZE * img_size * img_size * 3 * 4 * 8 / 1e9
        vram_percent = (estimated_vram / total_memory) * 100

        print(f"üìà VRAM estimado: {estimated_vram:.1f}GB ({vram_percent:.1f}%)")

        if vram_percent < 80:
            print("‚úÖ Configura√ß√£o segura para VRAM")
        elif vram_percent < 90:
            print("‚ö†Ô∏è Configura√ß√£o no limite - monitore a VRAM")
        else:
            print("‚ùå Configura√ß√£o pode causar out-of-memory")

    else:
        print("‚ùå CUDA n√£o dispon√≠vel")

check_gpu_memory()

# ================================
# DATASET PERSONALIZADO
# ================================

In [None]:
class CycleGANDataset(Dataset):
    """Dataset of image pairs drawn from two folders (domain A and domain B).

    With ``mode='train'`` images are read from ``<root>/trainA`` and
    ``<root>/trainB``; any other mode reads ``<root>/testA`` and
    ``<root>/testB``. The dataset length is the size of the larger domain and
    the smaller domain is cycled with a modulo index.

    Args:
        root_path: directory containing the domain sub-folders.
        mode: ``'train'`` or anything else for the test split.
        transform: optional callable applied to each PIL image.
    """

    # Extensions accepted as images (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_path, mode='train', transform=None):
        self.transform = transform

        split = 'train' if mode == 'train' else 'test'
        root = Path(root_path)
        self.path_A = root / f'{split}A'
        self.path_B = root / f'{split}B'

        self.images_A = self._list_images(self.path_A)
        self.images_B = self._list_images(self.path_B)

        self.length = max(len(self.images_A), len(self.images_B))

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted image files in ``directory`` ([] if it is missing)."""
        if not directory.is_dir():
            return []
        return sorted(
            entry for entry in directory.iterdir()
            if entry.is_file() and entry.suffix.lower() in cls.IMAGE_EXTENSIONS
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as handle:
            return handle.convert('RGB')

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{'A': image, 'B': image}`` for position ``idx``.

        Raises:
            IndexError: if either domain folder contains no images.
        """
        if not self.images_A or not self.images_B:
            raise IndexError(
                f"CycleGANDataset needs images in both domains: "
                f"{len(self.images_A)} found in {self.path_A}, "
                f"{len(self.images_B)} found in {self.path_B}"
            )

        # Modulo keeps the shorter domain in range by cycling through it.
        img_A = self._load_rgb(self.images_A[idx % len(self.images_A)])
        img_B = self._load_rgb(self.images_B[idx % len(self.images_B)])

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {'A': img_A, 'B': img_B}

# ================================
# TRANSFORMAÇÕES
# ================================

In [None]:
# Preprocessing applied to images of both domains: resize to a square of
# config.IMG_SIZE, convert to a tensor, then normalize each channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((config.IMG_SIZE, config.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # (x - 0.5) / 0.5 -> [-1, 1]
])


# ================================
# BLOCOS DE CONSTRUÇÃO
# ================================

In [None]:
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv + instance-norm stages with a skip connection.

    A ReLU sits between the two stages; the block output is the input plus
    the result of the second stage.
    """

    def __init__(self, channels):
        super().__init__()
        stages = []
        for has_activation in (True, False):
            stages.append(nn.ReflectionPad2d(1))
            stages.append(nn.Conv2d(channels, channels, 3))
            stages.append(nn.InstanceNorm2d(channels))
            if has_activation:
                stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

# ================================
# GERADOR (ResNet-based)
# ================================

In [None]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem (64 channels) -> two stride-2 convolutions that double
    the channel count -> ``num_residual_blocks`` residual blocks -> two
    bilinear-upsample + 3x3 conv stages that halve the channel count ->
    7x7 output convolution followed by Tanh.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        num_residual_blocks: number of residual blocks; ``None`` falls back
            to ``config.NUM_RESIDUAL_BLOCKS``.
    """

    def __init__(self, input_channels=3, output_channels=3, num_residual_blocks=9):
        super().__init__()

        n_blocks = (config.NUM_RESIDUAL_BLOCKS
                    if num_residual_blocks is None else num_residual_blocks)

        # Stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each step halves the resolution and doubles the width.
        width = 64
        for _ in range(2):
            wider = width * 2
            layers.extend([
                nn.Conv2d(width, wider, 3, stride=2, padding=1),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            width = wider

        # Bottleneck
        for _ in range(n_blocks):
            layers.append(ResidualBlock(width))

        # Decoder: upsample then convolve, halving the width each step.
        for _ in range(2):
            narrower = width // 2
            layers.extend([
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(width, narrower, kernel_size=3, stride=1, padding=0),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower

        # Head
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_channels, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# ================================
# DISCRIMINADOR (PatchGAN)
# ================================

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic with spectral normalization.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 channels; the first has
    no instance norm) are followed by global average pooling and one
    spectrally-normalized linear layer, so the output is a single raw score
    per image (no sigmoid is applied).
    """

    def __init__(self, input_channels=3):
        super().__init__()

        widths = [input_channels, 64, 128, 256, 512]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(spectral_norm(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1)))
            if stage > 0:  # the first stage is left un-normalized
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Head: pool to 1x1, flatten, and map to one score with a single layer.
        layers.append(nn.AdaptiveAvgPool2d(1))
        layers.append(nn.Flatten())
        layers.append(spectral_norm(nn.Linear(512, 1)))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

# ================================
# INICIALIZAÇÃO DOS MODELOS
# ================================

In [None]:
def weights_init_normal(m):
    """Initialize one module in place; meant for ``model.apply(...)``.

    * Modules whose class name contains ``Conv``: weights ~ N(0, 0.01).
      For layers wrapped by ``torch.nn.utils.spectral_norm`` the underlying
      ``weight_orig`` parameter is initialized, because ``weight`` is
      recomputed from it on every forward pass.
    * ``BatchNorm2d`` / ``InstanceNorm2d``: weight ~ N(1, 0.02), bias = 0,
      when the layer has affine parameters.

    Modules without the relevant tensors are skipped instead of relying on a
    blanket ``except`` to hide the failure.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        weight = getattr(m, 'weight_orig', None)
        if weight is None:
            weight = getattr(m, 'weight', None)
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight.data, 0.0, 0.01)
    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        # Non-affine norm layers carry weight=None / bias=None.
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
            if isinstance(bias, torch.Tensor):
                torch.nn.init.constant_(bias.data, 0.0)

In [None]:
# Build the two generators (A->B and B->A) and one discriminator per domain,
# all placed on the configured device.
print("üîß Criando modelos...")
G_AB, G_BA = (
    Generator(num_residual_blocks=config.NUM_RESIDUAL_BLOCKS).to(config.DEVICE)
    for _ in range(2)
)
D_A, D_B = (Discriminator().to(config.DEVICE) for _ in range(2))

# Custom weight initialization; on failure the networks keep whatever
# weights they have at that point and the notebook carries on.
print("üîß Inicializando pesos...")
try:
    G_AB.apply(fn=weights_init_normal)
    G_BA.apply(fn=weights_init_normal)
    D_A.apply(fn=weights_init_normal)
    D_B.apply(fn=weights_init_normal)
    print("‚úÖ Pesos inicializados com sucesso")
except Exception as e:
    print(f"‚ö†Ô∏è Inicializa√ß√£o customizada falhou: {e}")
    print("üîÑ Usando inicializa√ß√£o padr√£o PyTorch")

# ================================
# LOSS FUNCTIONS
# ================================

In [None]:
class CycleGANLoss:
    """Bundle of the loss terms used by the training step.

    All adversarial variants use smoothed targets instead of hard 0/1 labels.

    NOTE(review): the two BCE-based variants use ``nn.BCELoss``, which expects
    predictions in [0, 1]; the Discriminator in this file ends in a linear
    layer without a sigmoid -- confirm before using them with it.
    """

    def __init__(self):
        self.bce_loss = nn.BCELoss()
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

    @staticmethod
    def _soft_target(pred, target_is_real, real_value, fake_value):
        """Return a tensor shaped like ``pred`` filled with the smoothed label."""
        if target_is_real:
            return torch.ones_like(pred) * real_value
        return torch.zeros_like(pred) + fake_value

    def adversarial_loss(self, pred, target_is_real):
        """BCE against targets 0.95 (real) / 0.05 (fake)."""
        labels = self._soft_target(pred, target_is_real, 0.95, 0.05)
        return self.bce_loss(pred, labels)

    def adversarial_loss_lsgan(self, pred, target_is_real):
        """Half the MSE against targets 0.9 (real) / 0.1 (fake)."""
        labels = self._soft_target(pred, target_is_real, 0.9, 0.1)
        return 0.5 * self.mse_loss(pred, labels)

    def adversarial_loss_smooth(self, pred, target_is_real):
        """BCE against the softest targets: 0.75 (real) / 0.25 (fake)."""
        labels = self._soft_target(pred, target_is_real, 0.75, 0.25)
        return self.bce_loss(pred, labels)

    def cycle_consistency_loss(self, real, cycled):
        """L1 distance between an image and its round-trip reconstruction."""
        return self.l1_loss(real, cycled)

    def identity_loss(self, real, same):
        """L1 distance between an image and the generator's output for it."""
        return self.l1_loss(real, same)

# ================================
# OTIMIZADORES
# ================================

In [None]:
# One Adam optimizer drives both generators; each discriminator has its own.
optimizer_G = torch.optim.Adam(
    params=itertools.chain(G_AB.parameters(), G_BA.parameters()),
    lr=config.LR_G,
    betas=(config.BETA1, config.BETA2),
)

optimizer_D_A = torch.optim.Adam(
    params=D_A.parameters(),
    lr=config.LR_D,
    betas=(config.BETA1, config.BETA2),
)
optimizer_D_B = torch.optim.Adam(
    params=D_B.parameters(),
    lr=config.LR_D,
    betas=(config.BETA1, config.BETA2),
)


def lambda_rule(epoch):
    """Learning-rate multiplier for ``epoch``.

    Stays at 1.0 until ``config.DECAY_EPOCH`` and then decreases linearly
    over the remaining epochs.
    """
    decay_span = config.NUM_EPOCHS - config.DECAY_EPOCH + 1
    epochs_into_decay = max(0, epoch + 1 - config.DECAY_EPOCH)
    return 1.0 - epochs_into_decay / decay_span


# All three optimizers follow the same schedule.
scheduler_G, scheduler_D_A, scheduler_D_B = (
    torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    for optimizer in (optimizer_G, optimizer_D_A, optimizer_D_B)
)

# ================================
# BUFFER PARA IMAGENS FAKE
# ================================

In [None]:
class ImageBuffer:
    """Fixed-capacity pool of previously generated images.

    While the pool is filling, ``query`` stores and returns each incoming
    image. Once full, each incoming image is, with probability 0.5, swapped
    with a randomly chosen stored image (the stored one is returned);
    otherwise the incoming image is returned and the pool is unchanged.
    A capacity of 0 disables the pool.
    """

    def __init__(self, buffer_size=50):
        self.buffer_size = buffer_size
        self.buffer = []

    def query(self, images):
        """Return a batch with the same number of images as ``images``."""
        if self.buffer_size == 0:
            return images

        selected = []
        for sample in images.detach().clone():
            sample = torch.unsqueeze(sample.data, 0)

            # Pool still filling: remember the image and pass it through.
            if len(self.buffer) < self.buffer_size:
                self.buffer.append(sample.clone())
                selected.append(sample)
                continue

            if np.random.uniform(0, 1) > 0.5:
                # Swap: hand out an old image and store the new one instead.
                slot = np.random.randint(0, self.buffer_size)
                selected.append(self.buffer[slot].clone())
                self.buffer[slot] = sample.clone()
            else:
                selected.append(sample)

        return torch.cat(selected, 0)


fake_A_buffer = ImageBuffer()
fake_B_buffer = ImageBuffer()


In [None]:
# Create one gradient scaler per optimizer when mixed precision is requested.
# If AMP is unavailable or the scalers cannot be built, mixed precision is
# switched off on the config object so later cells see a consistent state.
if config.USE_AMP:
    if AMP_AVAILABLE:
        try:
            scalers = {name: GradScaler() for name in ('G', 'D_A', 'D_B')}
            print("‚úÖ Mixed Precision (nova sintaxe) configurado")
        except Exception as e:
            print(f"‚ö†Ô∏è Erro ao configurar AMP: {e}")
            config.USE_AMP = False
    else:
        config.USE_AMP = False
        print("‚ùå Mixed Precision desabilitado")


# ================================
# DATASET E DATALOADERS
# ================================

In [None]:
# Train and test splits share the same root folder and preprocessing.
train_dataset = CycleGANDataset(root_path=config.DATASET_PATH, mode='train', transform=transform)
test_dataset = CycleGANDataset(root_path=config.DATASET_PATH, mode='test', transform=transform)

# Training batches: shuffled, incomplete final batch dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=config.NUM_WORKERS,
    pin_memory=True,
    drop_last=True,
)

# Test batches: fixed order, half as many worker processes as training.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=config.NUM_WORKERS // 2,
    pin_memory=True,
)

print(f"Train dataset: {len(train_dataset)} amostras")
print(f"Test dataset: {len(test_dataset)} amostras")

# ================================
# FUNÇÕES DE VISUALIZAÇÃO
# ================================

In [None]:
def save_sample_images(real_A, real_B, fake_A, fake_B, cycle_A, cycle_B, epoch, batch_i):
    """Save a 2x3 preview grid built from the first image of each batch.

    Row 1: real_A, fake_B, cycle_A. Row 2: real_B, fake_A, cycle_B.
    The file goes to ``<MODEL_SAVE_PATH>/images``. Errors are printed and
    swallowed so a failed preview never interrupts training.
    """
    os.makedirs(f"{config.MODEL_SAVE_PATH}/images", exist_ok=True)

    try:
        panels = (real_A, fake_B, cycle_A, real_B, fake_A, cycle_B)
        imgs = torch.cat([batch[0:1] for batch in panels], dim=0)

        grid = make_grid(imgs, nrow=3, normalize=True, value_range=(-1, 1), padding=2)
        save_path = f"{config.MODEL_SAVE_PATH}/images/epoch_{epoch:03d}_batch_{batch_i:04d}.png"
        save_image(grid, save_path)

        # Announce only the first batch of an epoch to keep the console quiet.
        if batch_i == 0:
            print(f"    üíæ Sample salvo: epoch_{epoch:03d}_batch_{batch_i:04d}.png")

    except Exception as e:
        print(f"    ‚ö†Ô∏è Erro ao salvar imagem: {e}")

In [None]:
def save_checkpoint_robust(epoch, G_AB, G_BA, D_A, D_B, optimizer_G, optimizer_D_A, optimizer_D_B, 
                          scheduler_G, scheduler_D_A, scheduler_D_B, history, best_loss=None):
    """Write a full training checkpoint and prune old ones.

    Saves model, optimizer and scheduler state plus the loss history to
    ``checkpoint_epoch_XXX.pth`` and ``latest_checkpoint.pth`` under
    ``config.MODEL_SAVE_PATH``; also writes ``best_model.pth`` when the last
    recorded generator loss beats ``best_loss``. Only the three most recent
    per-epoch checkpoints are kept.

    Args:
        epoch: epoch index stored in the checkpoint and used in the file name.
        G_AB, G_BA, D_A, D_B: networks whose ``state_dict`` is saved.
        optimizer_G, optimizer_D_A, optimizer_D_B: optimizers to save.
        scheduler_G, scheduler_D_A, scheduler_D_B: LR schedulers to save.
        history: dict of metric lists; ``history['G_loss'][-1]`` is the
            value compared against ``best_loss``.
        best_loss: best generator loss so far (``None`` means no best yet).

    Returns:
        The updated best loss. If saving fails, the error is printed and the
        incoming ``best_loss`` (or ``inf``) is returned.
    """

    try:
        os.makedirs(config.MODEL_SAVE_PATH, exist_ok=True)

        # Decide whether this is the best model so far.
        current_loss = history['G_loss'][-1] if history['G_loss'] else float('inf')
        if best_loss is None:
            best_loss = float('inf')
        is_best = current_loss < best_loss

        # Configs that define a single LR (no LR_G / LR_D) must not abort
        # the save, so fall back to LR and then to None.
        shared_lr = getattr(config, 'LR', None)

        checkpoint_data = {
            'epoch': epoch,
            'timestamp': datetime.now().isoformat(),
            'best_loss': min(best_loss, current_loss),
            'G_AB_state_dict': G_AB.state_dict(),
            'G_BA_state_dict': G_BA.state_dict(),
            'D_A_state_dict': D_A.state_dict(),
            'D_B_state_dict': D_B.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_A_state_dict': optimizer_D_A.state_dict(),
            'optimizer_D_B_state_dict': optimizer_D_B.state_dict(),
            'scheduler_G_state_dict': scheduler_G.state_dict(),
            'scheduler_D_A_state_dict': scheduler_D_A.state_dict(),
            'scheduler_D_B_state_dict': scheduler_D_B.state_dict(),
            'history': history,
            'config_snapshot': {
                'batch_size': config.BATCH_SIZE,
                'lr_g': getattr(config, 'LR_G', shared_lr),
                'lr_d': getattr(config, 'LR_D', shared_lr),
                'lambda_cycle': config.LAMBDA_CYCLE,
                'lambda_identity': config.LAMBDA_IDENTITY,
                'num_epochs': config.NUM_EPOCHS,
                'use_amp': config.USE_AMP
            }
        }

        # Per-epoch checkpoint
        checkpoint_path = f"{config.MODEL_SAVE_PATH}/checkpoint_epoch_{epoch:03d}.pth"
        torch.save(checkpoint_data, checkpoint_path)

        # Always refresh the "latest" checkpoint
        torch.save(checkpoint_data, f"{config.MODEL_SAVE_PATH}/latest_checkpoint.pth")

        # Best-model copy when applicable
        if is_best:
            torch.save(checkpoint_data, f"{config.MODEL_SAVE_PATH}/best_model.pth")
            print(f"    üíé Novo melhor modelo! G Loss: {current_loss:.4f}")

        print(f"    üíæ Checkpoint salvo: epoch_{epoch:03d}.pth")

        # Best-effort cleanup: keep only the three newest per-epoch files.
        # Only filesystem errors are tolerated here, so an interrupt or a
        # programming error is no longer silently swallowed.
        try:
            checkpoint_dir = Path(config.MODEL_SAVE_PATH)
            checkpoints = []
            for f in checkpoint_dir.glob("checkpoint_epoch_*.pth"):
                try:
                    epoch_num = int(f.stem.split('_')[-1])
                except ValueError:
                    continue  # file name does not end in an epoch number
                checkpoints.append((epoch_num, f))

            checkpoints.sort(key=lambda x: x[0])
            while len(checkpoints) > 3:
                _, old_file = checkpoints.pop(0)
                old_file.unlink()
        except OSError as cleanup_error:
            print(f"    Aviso: falha ao limpar checkpoints antigos: {cleanup_error}")

        return min(best_loss, current_loss)

    except Exception as e:
        print(f"    ‚ö†Ô∏è Erro ao salvar checkpoint: {e}")
        return best_loss if best_loss is not None else float('inf')

In [None]:
def try_load_checkpoint():
    """Load ``latest_checkpoint.pth`` from the model folder if it exists.

    Returns:
        The loaded checkpoint object, or ``None`` when the file is absent or
        loading fails (the failure is printed, not raised).
    """
    latest_path = f"{config.MODEL_SAVE_PATH}/latest_checkpoint.pth"

    if not os.path.exists(latest_path):
        return None

    try:
        print(f"üìÇ Checkpoint encontrado: {latest_path}")
        # Loaded without asking for confirmation.
        checkpoint = torch.load(latest_path, map_location=config.DEVICE)
        print(f"‚úÖ Checkpoint da √©poca {checkpoint['epoch']} carregado automaticamente")
    except Exception as e:
        print(f"‚ùå Erro ao carregar checkpoint: {e}")
        return None

    return checkpoint

# ================================
# LOOP DE TREINAMENTO
# ================================

In [None]:
def add_discriminator_noise(images):
    """Return ``images`` plus Gaussian noise.

    The noise scale is drawn once per call, uniformly from [0.05, 0.15).
    """
    sigma = 0.05 + 0.10 * torch.rand(1).item()
    return images + torch.randn_like(images) * sigma

def ultra_stable_train_step(real_A, real_B, models, optimizers, criterion, config, batch_i):
    """Run one training step: optional discriminator update, then one generator update.

    The discriminators are updated only when ``batch_i`` is even; on odd
    batches their predictions are computed without gradients purely for
    logging and their losses are reported as 0.0. The generators are updated
    on every call.

    NOTE(review): the previous docstring mentioned "5 D-steps"; the code
    performs at most one discriminator step per call.

    Args:
        real_A: batch of domain-A images.
        real_B: batch of domain-B images.
        models: dict with keys 'G_AB', 'G_BA', 'D_A', 'D_B'.
        optimizers: dict with keys 'G', 'D_A', 'D_B'.
        criterion: object providing ``adversarial_loss_lsgan``,
            ``cycle_consistency_loss`` and ``identity_loss``.
        config: object providing ``LAMBDA_CYCLE`` and ``LAMBDA_IDENTITY``.
        batch_i: batch index; its parity decides whether D is trained.

    Returns:
        dict with scalar losses ('loss_G', 'loss_D_A', 'loss_D_B',
        'loss_cycle', 'loss_GAN'), mean discriminator outputs ('pred_real_A',
        'pred_fake_A', 'pred_real_B', 'pred_fake_B') and image tensors
        ('fake_A', 'fake_B', 'cycle_A', 'cycle_B'). The 'pred_fake_*' values
        are the ones computed during the generator step, which overwrite the
        earlier discriminator-phase values.
    """
    
    G_AB, G_BA, D_A, D_B = models['G_AB'], models['G_BA'], models['D_A'], models['D_B']
    opt_G, opt_D_A, opt_D_B = optimizers['G'], optimizers['D_A'], optimizers['D_B']
    
    if batch_i % 2 == 0:  # discriminators are trained on even batches only
        
        # DISCRIMINATOR A (sees real and fake images without added noise)
        opt_D_A.zero_grad()
        pred_real_A = D_A(real_A)
        loss_D_real_A = criterion.adversarial_loss_lsgan(pred_real_A, True)
        
        # Fakes are produced without gradients: only D_A is updated here.
        with torch.no_grad():
            fake_A = G_BA(real_B)
        
        pred_fake_A = D_A(fake_A.detach())
        loss_D_fake_A = criterion.adversarial_loss_lsgan(pred_fake_A, False)
        
        loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
        loss_D_A.backward()
        torch.nn.utils.clip_grad_norm_(D_A.parameters(), max_norm=1.0)
        opt_D_A.step()
        
        # DISCRIMINATOR B (unlike D_A, its inputs get random Gaussian noise)
        opt_D_B.zero_grad()
        pred_real_B = D_B(add_discriminator_noise(real_B))
        loss_D_real_B = criterion.adversarial_loss_lsgan(pred_real_B, True)
        
        with torch.no_grad():
            fake_B = G_AB(real_A)
        
        pred_fake_B = D_B(add_discriminator_noise(fake_B.detach()))
        loss_D_fake_B = criterion.adversarial_loss_lsgan(pred_fake_B, False)
        
        loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
        loss_D_B.backward()
        torch.nn.utils.clip_grad_norm_(D_B.parameters(), max_norm=1.0)
        opt_D_B.step()
    
    else:
        # D is not trained on this batch, but the predictions are still
        # needed for the returned log values.
        with torch.no_grad():
            fake_A = G_BA(real_B)
            fake_B = G_AB(real_A)
            pred_real_A = D_A(real_A)
            pred_fake_A = D_A(fake_A)
            pred_real_B = D_B(real_B) 
            pred_fake_B = D_B(fake_B)
            loss_D_A = torch.tensor(0.0)  # placeholder for logging
            loss_D_B = torch.tensor(0.0)  # placeholder for logging

    
    # =================
    # GENERATOR UPDATE - a single step per call
    # ==================
    
    opt_G.zero_grad()
    
    # Regenerate the fakes, this time with gradients flowing to the generators.
    fake_B = G_AB(real_A)
    fake_A = G_BA(real_B)
    
    # Adversarial term: generators want the discriminators to answer "real".
    pred_fake_B = D_B(fake_B)
    loss_GAN_AB = criterion.adversarial_loss_lsgan(pred_fake_B, True)
    
    pred_fake_A = D_A(fake_A)
    loss_GAN_BA = criterion.adversarial_loss_lsgan(pred_fake_A, True)
    
    loss_GAN = (loss_GAN_AB + loss_GAN_BA) * 1.0  # sum of both directions, weight 1.0
    
    # Cycle-consistency term: A -> B -> A and B -> A -> B.
    cycle_A = G_BA(fake_B)
    cycle_B = G_AB(fake_A)
    
    loss_cycle_A = criterion.cycle_consistency_loss(real_A, cycle_A)
    loss_cycle_B = criterion.cycle_consistency_loss(real_B, cycle_B)
    loss_cycle = (loss_cycle_A + loss_cycle_B) * 0.5
    
    # Identity term: each generator is fed an image of its own target domain.
    # Skipped entirely when its weight is not positive.
    if config.LAMBDA_IDENTITY > 0:
        identity_A = G_BA(real_A)
        identity_B = G_AB(real_B)
        loss_identity_A = criterion.identity_loss(real_A, identity_A)
        loss_identity_B = criterion.identity_loss(real_B, identity_B)
        loss_identity = (loss_identity_A + loss_identity_B) * 0.5
    else:
        loss_identity = 0
    
    # Total generator loss: GAN + weighted cycle + weighted identity.
    loss_G = (loss_GAN + 
         config.LAMBDA_CYCLE * loss_cycle + 
         config.LAMBDA_IDENTITY * loss_identity)
    
    loss_G.backward()
    
    # Gradient clipping, applied to each generator separately.
    torch.nn.utils.clip_grad_norm_(G_AB.parameters(), max_norm=1.0)
    torch.nn.utils.clip_grad_norm_(G_BA.parameters(), max_norm=1.0)
    
    opt_G.step() 

    # Switch to eval mode to produce the images returned for visualization.
    G_AB.eval()
    G_BA.eval()

    with torch.no_grad():
        display_fake_B = G_AB(real_A)  # generated after the update, no gradients
        display_fake_A = G_BA(real_B)  # generated after the update, no gradients

    # Restore training mode before returning.
    G_AB.train()
    G_BA.train()

    
    return {
        'loss_G': loss_G.item(),
        'loss_D_A': loss_D_A.item(),
        'loss_D_B': loss_D_B.item(),
        'loss_cycle': loss_cycle.item(),
        'loss_GAN': loss_GAN.item(),
        'pred_real_A': pred_real_A.mean().item(),
        'pred_fake_A': pred_fake_A.mean().item(),
        'pred_real_B': pred_real_B.mean().item(),
        'pred_fake_B': pred_fake_B.mean().item(),
        'fake_A': display_fake_A, 
        'fake_B': display_fake_B, 
        'cycle_A': cycle_A.detach(), 
        'cycle_B': cycle_B.detach()  
    }

In [None]:
def train_cyclegan():
    """Run the CycleGAN training loop, resuming from a checkpoint when one exists.

    Relies on notebook-level globals: ``config``, the networks ``G_AB``, ``G_BA``,
    ``D_A``, ``D_B``, their optimizers and schedulers, ``train_loader`` and the
    helpers ``CycleGANLoss``, ``try_load_checkpoint``, ``ultra_stable_train_step``,
    ``save_sample_images`` and ``save_checkpoint_robust``.

    Returns:
        dict: per-epoch history with the keys ``G_loss``, ``D_A_loss``,
        ``D_B_loss``, ``cycle_loss``, ``adv_loss`` (batch means) and
        ``epoch_times`` (seconds). Also returned when training is interrupted
        with Ctrl+C.

    Raises:
        ValueError: if epochs remain to be trained but ``train_loader`` yields
            no batches (previously this surfaced as a ZeroDivisionError after
            the first epoch).
        Exception: any error raised while training is re-raised after an
            emergency checkpoint has been attempted.
    """
    print("üöÄ Iniciando treinamento CycleGAN MELHORADO")
    print(f"üìä Configura√ß√£o: Œª_cycle={config.LAMBDA_CYCLE}, Œª_identity={config.LAMBDA_IDENTITY}")

    criterion = CycleGANLoss()

    # Output directories for checkpoints and sample images.
    os.makedirs(config.MODEL_SAVE_PATH, exist_ok=True)
    os.makedirs(f"{config.MODEL_SAVE_PATH}/images", exist_ok=True)

    # Resume right after the epoch stored in the checkpoint, if there is one.
    checkpoint = try_load_checkpoint()
    start_epoch = checkpoint['epoch'] + 1 if checkpoint else 0

    history = {
        'G_loss': [], 'D_A_loss': [], 'D_B_loss': [],
        'cycle_loss': [], 'adv_loss': [], 'epoch_times': []
    }
    best_loss = float('inf')

    if checkpoint:
        G_AB.load_state_dict(checkpoint['G_AB_state_dict'])
        G_BA.load_state_dict(checkpoint['G_BA_state_dict'])
        D_A.load_state_dict(checkpoint['D_A_state_dict'])
        D_B.load_state_dict(checkpoint['D_B_state_dict'])

        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D_A.load_state_dict(checkpoint['optimizer_D_A_state_dict'])
        optimizer_D_B.load_state_dict(checkpoint['optimizer_D_B_state_dict'])

        scheduler_G.load_state_dict(checkpoint['scheduler_G_state_dict'])
        scheduler_D_A.load_state_dict(checkpoint['scheduler_D_A_state_dict'])
        scheduler_D_B.load_state_dict(checkpoint['scheduler_D_B_state_dict'])

        history = checkpoint['history']
        best_loss = checkpoint.get('best_loss', float('inf'))

        print(f"üìÇ Checkpoint carregado - continuando da √©poca {start_epoch}")

    # NOTE: no GradScaler is created here because none is passed to
    # ultra_stable_train_step; config.USE_AMP is therefore not consulted.

    # Loop-invariant lookup tables handed to the train step on every batch.
    models = {
        'G_AB': G_AB,
        'G_BA': G_BA,
        'D_A': D_A,
        'D_B': D_B
    }
    optimizers = {
        'G': optimizer_G,
        'D_A': optimizer_D_A,
        'D_B': optimizer_D_B
    }

    # history key -> key of the dict returned by ultra_stable_train_step
    tracked_losses = {
        'G_loss': 'loss_G',
        'D_A_loss': 'loss_D_A',
        'D_B_loss': 'loss_D_B',
        'cycle_loss': 'loss_cycle',
        'adv_loss': 'loss_GAN',
    }

    num_batches = len(train_loader)
    if start_epoch < config.NUM_EPOCHS and num_batches == 0:
        # Fail early and clearly instead of dividing by zero at epoch end.
        raise ValueError("train_loader is empty: there are no batches to train on")

    def _save_checkpoint(epoch_idx):
        """Persist the complete training state; returns the updated best loss."""
        return save_checkpoint_robust(
            epoch_idx, G_AB, G_BA, D_A, D_B,
            optimizer_G, optimizer_D_A, optimizer_D_B,
            scheduler_G, scheduler_D_A, scheduler_D_B,
            history, best_loss
        )

    def _emergency_checkpoint(epoch_idx):
        """Best-effort save from an error handler; never masks the original error."""
        try:
            _save_checkpoint(epoch_idx)
        except Exception as save_error:
            print(f"Emergency checkpoint could not be saved: {save_error}")

    print(f"\nüé¨ INICIANDO TREINAMENTO - √âpocas {start_epoch} a {config.NUM_EPOCHS-1}")

    total_start_time = time.time()

    # Pre-bind so the handlers below can always reference `epoch`.
    epoch = start_epoch

    try:
        for epoch in range(start_epoch, config.NUM_EPOCHS):
            epoch_start_time = time.time()
            epoch_totals = dict.fromkeys(tracked_losses, 0.0)

            for batch_i, batch in enumerate(train_loader):
                real_B = batch['B'].to(config.DEVICE, non_blocking=True)
                real_A = batch['A'].to(config.DEVICE, non_blocking=True)

                results = ultra_stable_train_step(
                    real_A, real_B, models, optimizers, criterion, config, batch_i
                )

                for history_key, result_key in tracked_losses.items():
                    epoch_totals[history_key] += results[result_key]

                # Periodic console log, including a discriminator health flag.
                if batch_i % config.LOG_FREQ == 0:
                    progress = ((epoch * num_batches + batch_i) / (config.NUM_EPOCHS * num_batches)) * 100

                    if torch.cuda.is_available():
                        vram_used = torch.cuda.memory_allocated() / 1e9
                        vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9
                        vram_percent = (vram_used / vram_total) * 100
                    else:
                        vram_percent = 0

                    # Both discriminator losses above 0.1 -> first marker, otherwise the second.
                    d_status = "üî•" if (results['loss_D_A'] > 0.1 and results['loss_D_B'] > 0.1) else "üíÄ"

                    print(f"E{epoch:03d}-B{batch_i:04d} ({progress:.1f}%) {d_status} | "
                          f"G:{results['loss_G']:.4f} Cyc:{results['loss_cycle']:.4f} | "
                          f"D_A:{results['loss_D_A']:.4f} D_B:{results['loss_D_B']:.4f} | "
                          f"Pred: rA:{results['pred_real_A']:.2f} fA:{results['pred_fake_A']:.2f} "
                          f"rB:{results['pred_real_B']:.2f} fB:{results['pred_fake_B']:.2f} | "
                          f"VRAM:{vram_percent:.1f}%")

                # Periodic sample grid written by save_sample_images.
                if batch_i % config.SAMPLE_FREQ == 0:
                    save_sample_images(
                        real_A, real_B,
                        results['fake_A'], results['fake_B'],
                        results['cycle_A'], results['cycle_B'],
                        epoch, batch_i
                    )

            # One scheduler step per epoch.
            scheduler_G.step()
            scheduler_D_A.step()
            scheduler_D_B.step()

            epoch_time = time.time() - epoch_start_time

            # Record the per-batch means of this epoch.
            epoch_means = {key: total / num_batches for key, total in epoch_totals.items()}
            for key, mean_value in epoch_means.items():
                history[key].append(mean_value)
            history['epoch_times'].append(epoch_time)

            current_lr = scheduler_G.get_last_lr()[0]
            avg_G_loss = epoch_means['G_loss']
            avg_cycle_loss = epoch_means['cycle_loss']

            print(f"\n‚úÖ √âPOCA {epoch}/{config.NUM_EPOCHS-1} conclu√≠da em {epoch_time:.1f}s")
            print(f"üìä G:{avg_G_loss:.4f} | Cycle:{avg_cycle_loss:.4f} | LR:{current_lr:.6f}")

            # Regular checkpoint every SAVE_FREQ epochs and on the final epoch.
            if epoch % config.SAVE_FREQ == 0 or epoch == config.NUM_EPOCHS - 1:
                best_loss = _save_checkpoint(epoch)

    except KeyboardInterrupt:
        print("\n‚ö†Ô∏è Treinamento interrompido pelo usu√°rio")
        # NOTE(review): the checkpoint records the in-progress epoch, and resume
        # starts at checkpoint['epoch'] + 1, so the unfinished part of this epoch
        # is skipped on resume - confirm this is intended.
        _emergency_checkpoint(epoch)
        return history

    except Exception as e:
        print(f"\n‚ùå Erro durante treinamento: {e}")
        _emergency_checkpoint(epoch)
        raise  # bare raise keeps the original traceback

    total_time = time.time() - total_start_time
    total_hours = total_time / 3600

    print(f"\nüéâ TREINAMENTO CONCLU√çDO!")
    print(f"‚è±Ô∏è Tempo total: {total_hours:.2f} horas")
    print(f"üìà Melhor Generator Loss: {best_loss:.4f}")
    print(f"üíæ Modelos salvos em: {config.MODEL_SAVE_PATH}")

    return history

# ================================
# FUNÇÃO PRINCIPAL
# ================================

In [None]:
# Entry point of the notebook: train only when the dataset folder exists,
# then save the per-epoch loss curves to training_curves.png.
if os.path.exists(config.DATASET_PATH):
    print(f"‚úÖ Dataset encontrado: {config.DATASET_PATH}")
    print(f"üéØ Treinando modelo baseline CycleGAN")

    # Train (or resume) and collect the per-epoch history.
    history = train_cyclegan()

    if not history or len(history['G_loss']) == 0:
        print("‚ö†Ô∏è Sem dados suficientes para plotar gr√°ficos")
    else:
        # Three side-by-side panels: generator, discriminators, loss components.
        plt.figure(figsize=(15, 5))

        # Panel 1: generator loss.
        plt.subplot(1, 3, 1)
        plt.plot(history['G_loss'], label='Generator Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Generator Loss')
        plt.legend()

        # Panel 2: both discriminator losses.
        plt.subplot(1, 3, 2)
        plt.plot(history['D_A_loss'], label='Discriminator A')
        plt.plot(history['D_B_loss'], label='Discriminator B')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Discriminator Losses')
        plt.legend()

        # Panel 3: cycle loss against adversarial loss.
        plt.subplot(1, 3, 3)
        plt.plot(history['cycle_loss'], label='Cycle Loss')
        plt.plot(history['adv_loss'], label='Adversarial Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Component Losses')
        plt.legend()

        plt.tight_layout()
        plt.savefig(
            f"{config.MODEL_SAVE_PATH}/training_curves.png",
            dpi=150,
            bbox_inches='tight',
        )

        print(f"üìä Gr√°ficos salvos em: {config.MODEL_SAVE_PATH}/training_curves.png")
        plt.close()  # release the figure's memory
else:
    print(f"‚ùå Dataset n√£o encontrado em {config.DATASET_PATH}")
    print("üí° Execute primeiro o notebook 'organize_datasets.ipynb'")

# ================================
# DIAGNÓSTICO DO PROBLEMA DE VISUALIZAÇÃO
# ================================

def test_visualization_issue():
    """Check whether generator output differs between train() and eval() mode.

    Runs both generators on the first image of the first ``test_loader`` batch,
    once in train mode and once in eval mode, prints the mean absolute
    difference and saves a side-by-side comparison image under
    ``{config.MODEL_SAVE_PATH}/debug``.

    The generators are put back into whatever mode they were in before the
    call, even if a forward pass raises.

    Returns:
        str | None: ``"train_eval_mode"`` when either difference exceeds 0.01,
        ``"other_cause"`` otherwise, ``"error"`` if an exception occurred, and
        ``None`` when the generators are missing or the test loader is empty.
    """
    print("üîç DIAGN√ìSTICO: Testando diferen√ßa train vs eval mode...")

    try:
        # The generators must already exist as notebook globals.
        if 'G_AB' not in globals() or 'G_BA' not in globals():
            print("‚ùå Modelos n√£o encontrados. Execute o treinamento primeiro.")
            return

        if len(test_loader) == 0:
            print("‚ùå Test loader vazio.")
            return

        # One image per domain is enough for the comparison.
        test_batch = next(iter(test_loader))
        real_A = test_batch['A'][:1].to(config.DEVICE)
        real_B = test_batch['B'][:1].to(config.DEVICE)

        print(f"‚úÖ Usando imagem de teste: {real_A.shape}")

        # Remember the current mode so the diagnostic leaves no side effect.
        previous_modes = (G_AB.training, G_BA.training)
        try:
            # Pass 1: train mode.
            # NOTE(review): if the generators contain BatchNorm layers, a
            # train-mode forward pass updates their running statistics - confirm.
            G_AB.train()
            G_BA.train()
            with torch.no_grad():
                fake_B_train = G_AB(real_A)
                fake_A_train = G_BA(real_B)

            # Pass 2: eval mode.
            G_AB.eval()
            G_BA.eval()
            with torch.no_grad():
                fake_B_eval = G_AB(real_A)
                fake_A_eval = G_BA(real_B)
        finally:
            G_AB.train(previous_modes[0])
            G_BA.train(previous_modes[1])

        # Mean absolute difference between the two passes.
        diff_A2B = torch.abs(fake_B_train - fake_B_eval).mean().item()
        diff_B2A = torch.abs(fake_A_train - fake_A_eval).mean().item()

        print(f"\nüìä RESULTADOS DO DIAGN√ìSTICO:")
        print(f"   Diferen√ßa A‚ÜíB (train vs eval): {diff_A2B:.6f}")
        print(f"   Diferen√ßa B‚ÜíA (train vs eval): {diff_B2A:.6f}")

        # Differences above this value count as significant.
        threshold = 0.01

        if diff_A2B > threshold or diff_B2A > threshold:
            print(f"\n‚úÖ CONFIRMADO: Problema √© train/eval mode!")
            print(f"   üí° Solu√ß√£o: Usar .eval() mode para visualiza√ß√£o")
            diagnosis = "train_eval_mode"
        else:
            print(f"\n‚ùå N√ÉO √© train/eval mode. Investigar outras causas:")
            print(f"   üí° Poss√≠veis causas: noise, batch norm, dropout, etc.")
            diagnosis = "other_cause"

        os.makedirs(f"{config.MODEL_SAVE_PATH}/debug", exist_ok=True)

        # Each row is [real | train-mode output | eval-mode output], joined
        # along the last axis of the per-image tensors.
        comparison_A2B = torch.cat([
            real_A[0], fake_B_train[0], fake_B_eval[0]
        ], dim=2)

        comparison_B2A = torch.cat([
            real_B[0], fake_A_train[0], fake_A_eval[0]
        ], dim=2)

        # Stack the A-to-B row on top of the B-to-A row.
        full_comparison = torch.cat([comparison_A2B, comparison_B2A], dim=1)

        debug_path = f"{config.MODEL_SAVE_PATH}/debug/diagnosis_train_vs_eval.png"
        save_image(full_comparison, debug_path, normalize=True, value_range=(-1, 1))

        print(f"\nüíæ Compara√ß√£o visual salva em:")
        print(f"   {debug_path}")
        print(f"   Layout: [Real | Train_mode | Eval_mode]")
        print(f"   Top row: A‚ÜíB, Bottom row: B‚ÜíA")

        return diagnosis

    except Exception as e:
        print(f"‚ùå Erro durante diagn√≥stico: {e}")
        return "error"

# Run the diagnosis automatically only when both the A-to-B generator and the
# test loader already exist in this kernel; otherwise just print a hint.
if 'G_AB' not in globals() or 'test_loader' not in globals():
    print("‚ö†Ô∏è Diagn√≥stico dispon√≠vel ap√≥s treinamento. Execute: test_visualization_issue()")
else:
    print("\n" + "=" * 50)
    print("üî¨ EXECUTANDO DIAGN√ìSTICO AUTOM√ÅTICO")
    print("=" * 50)
    diagnosis_result = test_visualization_issue()
    print("=" * 50)

# ================================
# Métricas
# ================================

In [None]:
config = config  # no-op self-assignment; raises NameError here if `config` was never defined in this kernel

def evaluate_existing_models(num_samples=100):
    """Generate translated and cycled samples from the test set for evaluation.

    Uses the notebook globals ``G_AB``, ``G_BA``, ``test_loader``,
    ``test_dataset`` and ``config``. Both generators are switched to eval mode
    and are left in eval mode when the function returns.

    Args:
        num_samples: maximum number of test samples to collect.

    Returns:
        dict | None: lists of per-sample CPU tensors under the keys ``real_A``,
        ``real_B``, ``fake_A``, ``fake_B``, ``cycle_A``, ``cycle_B`` plus the
        integer ``count``; ``None`` when a required global is not defined.
    """
    print("üöÄ AVALIA√á√ÉO QUANTITATIVA")
    print("=" * 50)
    print(f"üìä Configura√ß√£o:")

    # Every notebook global is read inside the guard, so a missing `config` or
    # `test_dataset` is reported the same way as a missing model.
    try:
        print(f"   Device: {config.DEVICE}")
        print(f"   Samples: {num_samples}")
        print(f"   Test dataset: {len(test_dataset)} amostras")
        print(f"   G_AB status: {type(G_AB).__name__}")
        print(f"   G_BA status: {type(G_BA).__name__}")
        print(f"   Test loader: {len(test_loader)} batches")
    except NameError as e:
        print(f"‚ùå Erro: {e}")
        print("üí° Execute primeiro o notebook de treinamento")
        return None

    G_AB.eval()
    G_BA.eval()

    print(f"\nüîÑ Gerando {num_samples} amostras para avalia√ß√£o...")

    sample_keys = ('real_A', 'real_B', 'fake_A', 'fake_B', 'cycle_A', 'cycle_B')
    collected = {key: [] for key in sample_keys}
    count = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            if count >= num_samples:
                break

            real_A = batch['A'].to(config.DEVICE)
            real_B = batch['B'].to(config.DEVICE)

            # Translations in both directions.
            fake_B = G_AB(real_A)
            fake_A = G_BA(real_B)

            # Round trips A->B->A and B->A->B.
            cycle_A = G_BA(fake_B)
            cycle_B = G_AB(fake_A)

            # Move each batch tensor to the CPU once, then slice per sample.
            cpu_batch = {
                'real_A': real_A.cpu(),
                'real_B': real_B.cpu(),
                'fake_A': fake_A.cpu(),
                'fake_B': fake_B.cpu(),
                'cycle_A': cycle_A.cpu(),
                'cycle_B': cycle_B.cpu(),
            }

            # Never collect more than the requested number of samples.
            take = min(real_A.size(0), num_samples - count)
            for i in range(take):
                for key in sample_keys:
                    collected[key].append(cpu_batch[key][i])
            count += take

            if batch_idx % 5 == 0:
                progress = (count / num_samples) * 100
                print(f"   Progresso: {count}/{num_samples} ({progress:.1f}%)")

    print(f"‚úÖ {count} amostras geradas com sucesso!")

    return {**collected, 'count': count}

def calculate_metrics(samples):
    """Compute cycle-consistency and pixel-level distances over the samples.

    Args:
        samples: dict with lists of tensors under ``real_A``, ``real_B``,
            ``fake_A``, ``fake_B``, ``cycle_A``, ``cycle_B`` and the number of
            samples under ``count`` (the dict built by
            ``evaluate_existing_models``).

    Returns:
        dict: three sub-dicts of Python floats:
        ``cycle_consistency`` (mean L1 of real vs. cycled images per direction
        and their average), ``pixel_similarity`` (mean L1 and MSE of real input
        vs. translated output per direction) and ``domain_transfer`` (the same
        L1 values plus ``asymmetry_ratio`` = B-to-A / A-to-B, which is ``None``
        when the A-to-B distance is zero).

    Raises:
        ValueError: if ``samples['count']`` is not positive.
    """
    print(f"\nüìä CALCULANDO M√âTRICAS")
    print("=" * 30)

    num_samples = samples['count']
    if num_samples <= 0:
        # Every average below divides by the sample count.
        raise ValueError("calculate_metrics needs at least one sample (samples['count'] <= 0)")

    l1_loss = torch.nn.L1Loss()
    mse_loss = torch.nn.MSELoss()

    def _summed(loss_fn, reference_key, output_key):
        """Sum of loss_fn(reference, output) over the paired sample lists."""
        return sum(
            loss_fn(reference, output).item()
            for reference, output in zip(samples[reference_key], samples[output_key])
        )

    # 1. Cycle consistency: real image vs. its round trip through both generators.
    print("üîÑ Cycle Consistency...")

    cycle_loss_A = _summed(l1_loss, 'real_A', 'cycle_A')  # A -> B -> A
    cycle_loss_B = _summed(l1_loss, 'real_B', 'cycle_B')  # B -> A -> B

    avg_cycle_A = cycle_loss_A / num_samples
    avg_cycle_B = cycle_loss_B / num_samples
    avg_cycle_total = (cycle_loss_A + cycle_loss_B) / (2 * num_samples)

    print(f"   A‚ÜíB‚ÜíA: {avg_cycle_A:.4f}")
    print(f"   B‚ÜíA‚ÜíB: {avg_cycle_B:.4f}")
    print(f"   M√©dia: {avg_cycle_total:.4f}")

    # 2. Pixel-level distance between each input and its translation.
    print("üñºÔ∏è Pixel-level metrics...")

    avg_l1_A_to_B = _summed(l1_loss, 'real_A', 'fake_B') / num_samples
    avg_l1_B_to_A = _summed(l1_loss, 'real_B', 'fake_A') / num_samples
    avg_mse_A_to_B = _summed(mse_loss, 'real_A', 'fake_B') / num_samples
    avg_mse_B_to_A = _summed(mse_loss, 'real_B', 'fake_A') / num_samples

    print(f"   L1 A‚ÜíB: {avg_l1_A_to_B:.4f}")
    print(f"   L1 B‚ÜíA: {avg_l1_B_to_A:.4f}")
    print(f"   MSE A‚ÜíB: {avg_mse_A_to_B:.4f}")
    print(f"   MSE B‚ÜíA: {avg_mse_B_to_A:.4f}")

    # 3. Domain transfer: the L1 distances above reused as a measure of how
    #    much each generator changes its input (no identity mapping is computed).
    print("üéØ Identity metrics...")

    domain_transfer_A_to_B = avg_l1_A_to_B
    domain_transfer_B_to_A = avg_l1_B_to_A

    print(f"   Domain transfer A‚ÜíB: {domain_transfer_A_to_B:.4f}")
    print(f"   Domain transfer B‚ÜíA: {domain_transfer_B_to_A:.4f}")

    # The ratio is undefined when G_AB reproduces its input exactly.
    if domain_transfer_A_to_B > 0:
        asymmetry_ratio = domain_transfer_B_to_A / domain_transfer_A_to_B
    else:
        asymmetry_ratio = None

    return {
        'cycle_consistency': {
            'A_to_B_to_A': avg_cycle_A,
            'B_to_A_to_B': avg_cycle_B,
            'average': avg_cycle_total
        },
        'pixel_similarity': {
            'L1_A_to_B': avg_l1_A_to_B,
            'L1_B_to_A': avg_l1_B_to_A,
            'MSE_A_to_B': avg_mse_A_to_B,
            'MSE_B_to_A': avg_mse_B_to_A
        },
        'domain_transfer': {
            'A_to_B_effectiveness': domain_transfer_A_to_B,
            'B_to_A_effectiveness': domain_transfer_B_to_A,
            'asymmetry_ratio': asymmetry_ratio
        }
    }

def compare_with_literature(metrics):
    """Rank the model's average cycle-consistency against hard-coded baselines.

    A lower score is better; a score equal to a baseline counts as superior.

    NOTE(review): the baseline numbers are hard-coded approximations (per the
    original author's comment) - verify them against the cited papers before
    quoting the ranking.

    Args:
        metrics: dict holding the score at ``metrics['cycle_consistency']['average']``.

    Returns:
        dict: ``rank`` (1 = best), ``total_methods`` (baselines + this model),
        ``better_than`` / ``worse_than`` (lists of baseline names) and
        ``quality_assessment`` (a textual verdict).
    """
    print(f"\nüìà COMPARA√á√ÉO COM LITERATURA")
    print("=" * 40)

    own_score = metrics['cycle_consistency']['average']

    # Reference scores, in the order they are reported.
    baselines = {
        'CycleGAN Original (Zhu et al., 2017)': {
            'cycle_consistency': 0.15,
            'domain': 'horse2zebra, apple2orange'
        },
        'AttentionGAN (Tang et al., 2019)': {
            'cycle_consistency': 0.12,
            'domain': 'face translation'
        },
        'UNIT (Liu et al., 2017)': {
            'cycle_consistency': 0.18,
            'domain': 'various'
        },
        'Pix2Pix (Isola et al., 2017) - supervised': {
            'cycle_consistency': 0.05,
            'domain': 'paired data'
        }
    }

    print(f"üéØ SEU MODELO:")
    print(f"   Cycle Consistency: {own_score:.4f}")
    print(f"   Domain: Real ‚Üî Cartoon (unpaired)")
    print(f"   M√©todo: CycleGAN + LSGAN + Spectral Norm + TTUR")

    print(f"\nüìä BASELINES PUBLICADOS:")

    better_than, worse_than = [], []

    for method, data in baselines.items():
        reference = data['cycle_consistency']
        beats_baseline = own_score <= reference

        (better_than if beats_baseline else worse_than).append(method)
        status = "‚úÖ SUPERIOR" if beats_baseline else "‚ö†Ô∏è inferior"

        print(f"   {status} {method}:")
        print(f"      Cycle: {reference:.4f} | Domain: {data['domain']}")

    # Every baseline that beats the model pushes it one place down.
    your_rank = 1 + len(worse_than)
    total_methods = len(baselines) + 1

    print(f"\nüèÜ POSI√á√ÉO: {your_rank}/{total_methods}")

    # First ceiling the score fits under decides the verdict.
    quality_ladder = (
        (0.08, "üéâ EXCEPCIONAL! N√≠vel state-of-the-art!"),
        (0.12, "üèÜ EXCELENTE! Competitivo com os melhores papers!"),
        (0.15, "‚úÖ MUITO BOM! Compar√°vel ao CycleGAN original!"),
        (0.20, "üëç BOM! Resultado aceit√°vel para research!"),
    )
    quality = next(
        (verdict for ceiling, verdict in quality_ladder if own_score <= ceiling),
        "‚ö†Ô∏è MODERADO. Margem para melhoria.",
    )

    print(f"\n{quality}")

    if better_than:
        print(f"‚úÖ Superior a: {len(better_than)} m√©todo(s)")
    if worse_than:
        print(f"‚ö†Ô∏è Inferior a: {len(worse_than)} m√©todo(s)")

    return {
        'rank': your_rank,
        'total_methods': total_methods,
        'better_than': better_than,
        'worse_than': worse_than,
        'quality_assessment': quality
    }

def save_evaluation_results(samples, metrics, comparison, model_info=None):
    """Write the evaluation results as a JSON file and a plain-text report.

    Both files are written with UTF-8 encoding under ``config.MODEL_SAVE_PATH``
    (created if missing), because the contents include non-ASCII text that the
    platform default encoding may not be able to represent.

    Args:
        samples: sample dict; only ``samples['count']`` is read.
        metrics: metrics dict as returned by ``calculate_metrics``.
        comparison: comparison dict as returned by ``compare_with_literature``.
        model_info: accepted for interface compatibility; currently unused.

    Returns:
        str: path of the JSON results file.
    """
    print(f"\nüíæ SALVANDO RESULTADOS")
    print("=" * 25)

    # Evaluation may run in a session where training never created the folder.
    os.makedirs(config.MODEL_SAVE_PATH, exist_ok=True)

    results = {
        'metadata': {
            'timestamp': datetime.now().isoformat(),
            'evaluation_type': 'quantitative_analysis',
            'num_samples': samples['count'],
            'model_path': config.MODEL_SAVE_PATH,
            'device': str(config.DEVICE)
        },
        'model_configuration': {
            'lambda_cycle': config.LAMBDA_CYCLE,
            'lambda_identity': config.LAMBDA_IDENTITY,
            'lr_g': config.LR_G,
            'lr_d': config.LR_D,
            'num_residual_blocks': config.NUM_RESIDUAL_BLOCKS,
            # The flags and the list below are hard-coded, not read from config.
            'use_spectral_norm': True,
            'use_ttur': True,
            'use_lsgan': True,
            'architecture_improvements': [
                'Spectral Normalization',
                'TTUR Learning Rates', 
                'Upsample + Conv (no deconvolution)',
                'Label Smoothing',
                'Reflection Padding'
            ]
        },
        'quantitative_metrics': metrics,
        'literature_comparison': comparison
    }

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened with an encoding that can represent them.
    results_path = f"{config.MODEL_SAVE_PATH}/quantitative_evaluation.json"
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"‚úÖ Resultados salvos em:")
    print(f"   {results_path}")

    # Short human-readable summary next to the JSON file.
    report_path = f"{config.MODEL_SAVE_PATH}/evaluation_report.txt"
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("RELAT√ìRIO DE AVALIA√á√ÉO QUANTITATIVA\n")
        f.write("="*50 + "\n\n")
        f.write(f"Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Amostras avaliadas: {samples['count']}\n\n")

        f.write("M√âTRICAS PRINCIPAIS:\n")
        f.write(f"Cycle Consistency: {metrics['cycle_consistency']['average']:.4f}\n")
        f.write(f"L1 A‚ÜíB: {metrics['pixel_similarity']['L1_A_to_B']:.4f}\n")
        f.write(f"L1 B‚ÜíA: {metrics['pixel_similarity']['L1_B_to_A']:.4f}\n\n")

        f.write(f"POSI√á√ÉO NA LITERATURA: {comparison['rank']}/{comparison['total_methods']}\n")
        f.write(f"Superior a {len(comparison['better_than'])} m√©todo(s)\n")
        f.write(f"Avalia√ß√£o: {comparison['quality_assessment']}\n")

    print(f"üìÑ Relat√≥rio de texto salvo em:")
    print(f"   {report_path}")

    return results_path

def run_complete_evaluation():
    """Run the whole evaluation pipeline: sample, measure, compare, save.

    Returns:
        dict | None: ``samples``, ``metrics``, ``comparison`` and
        ``results_path`` on success; ``None`` when sample generation reports
        missing models or when any step raises (the traceback is printed).
    """
    for banner_line in (
        "üéØ AVALIA√á√ÉO QUANTITATIVA COMPLETA",
        "üî¨ CycleGAN Real2Cartoon - An√°lise Final",
        "=" * 60,
    ):
        print(banner_line)

    try:
        # Step 1: generate samples with the models already in memory.
        samples = evaluate_existing_models(num_samples=100)
        if samples is not None:
            # Steps 2-4: metrics, literature comparison, persistence.
            metrics = calculate_metrics(samples)
            comparison = compare_with_literature(metrics)
            results_path = save_evaluation_results(samples, metrics, comparison)

            print(f"\nüéâ AVALIA√á√ÉO COMPLETA CONCLU√çDA!")
            print(f"üìä Cycle Consistency: {metrics['cycle_consistency']['average']:.4f}")
            print(f"üèÜ Ranking: {comparison['rank']}/{comparison['total_methods']}")
            print(f"üíæ Resultados: {results_path}")

            return dict(
                samples=samples,
                metrics=metrics,
                comparison=comparison,
                results_path=results_path,
            )
        return None

    except Exception as e:
        # Best-effort: report the failure with its traceback and keep the notebook running.
        import traceback
        print(f"‚ùå Erro durante avalia√ß√£o: {e}")
        traceback.print_exc()
        return None

In [None]:
# Run the full evaluation pipeline and keep its result dict (None if it failed or the models are missing).
evaluation_results = run_complete_evaluation()