In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import os
import numpy as np
import itertools
import datetime
import time
import torch.nn as nn
import torch
# from torchvision.utils import save_image
import sys 
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
from torch.utils.data import Dataset
from torch.utils.data import DataLoader



def weights_init_normal(m):
    """Initialize conv / BatchNorm2d weights in place; intended for ``Module.apply``.

    * Any module whose class name contains ``"Conv"`` (e.g. ``Conv2d``,
      ``ConvTranspose2d``) gets weights drawn from N(0, 0.02). Its bias, if
      any, is left untouched.
    * Any module whose class name contains ``"BatchNorm2d"`` gets weights drawn
      from N(1, 0.02) and a zero bias.
    * All other modules are ignored.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in kind:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


##############################
#           U-NET
##############################



# === Instance Normalization Custom (como no TF) ===
class InstanceNormalization(nn.Module):
    """Per-sample, per-channel normalization with a learnable affine transform.

    Mirrors a TensorFlow-style custom InstanceNormalization. Each (sample, channel)
    feature map is standardized over its spatial axes (H, W). It is then multiplied
    by ``scale`` and shifted by ``offset`` (one value per channel).

    Args:
        epsilon: added to the variance before the square root for numerical
            stability.
        num_features: number of channels ``C``. When given, ``scale`` and
            ``offset`` are registered immediately. This means they are included
            in ``module.parameters()`` when an optimizer is constructed and
            exist in a fresh model's ``state_dict``. When ``None`` (legacy
            behavior), they are created lazily on the first forward pass. Any
            optimizer built before that pass will then NOT update them.
    """

    def __init__(self, epsilon=1e-5, num_features=None):
        super().__init__()
        self.epsilon = epsilon
        if num_features is not None:
            self.scale = nn.Parameter(torch.ones(num_features))
            self.offset = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # x: (N, C, H, W); statistics are taken over H and W only.
        mean = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True, unbiased=False)
        inv = 1.0 / torch.sqrt(var + self.epsilon)
        normalized = (x - mean) * inv

        if not hasattr(self, 'scale'):
            # Legacy lazy path (num_features not given): infer C from the input.
            self.scale = nn.Parameter(torch.ones(x.size(1), device=x.device))
            self.offset = nn.Parameter(torch.zeros(x.size(1), device=x.device))
        # Reshape the per-channel parameters so they broadcast over (N, C, H, W).
        scale = self.scale.view(1, -1, 1, 1)
        offset = self.offset.view(1, -1, 1, 1)
        return scale * normalized + offset

# === Downsample and Upsample blocks ===
def downsample(in_ch, out_ch, norm_type='instancenorm', apply_norm=True):
    """Encoder stage: 4x4 stride-2 conv (no bias) -> optional norm -> LeakyReLU(0.2).

    The conv (kernel 4, stride 2, padding 1) halves even spatial sizes.
    ``norm_type`` selects ``'batchnorm'`` or ``'instancenorm'``. Any other value
    (or ``apply_norm=False``) adds no normalization layer.
    """
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)]
    if apply_norm:
        if norm_type == 'batchnorm':
            layers.append(nn.BatchNorm2d(out_ch))
        elif norm_type == 'instancenorm':
            # Pass the channel count so the affine parameters are registered
            # before any optimizer is created for the enclosing model.
            layers.append(InstanceNormalization(num_features=out_ch))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)

def upsample(in_ch, out_ch, norm_type='instancenorm', apply_dropout=False):
    """Decoder stage: 4x4 stride-2 transposed conv (no bias) -> norm -> ReLU -> optional Dropout(0.5).

    The transposed conv (kernel 4, stride 2, padding 1) doubles the spatial size.
    ``norm_type`` selects ``'batchnorm'`` or ``'instancenorm'``. Any other value
    adds no normalization layer.
    """
    layers = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)]
    if norm_type == 'batchnorm':
        layers.append(nn.BatchNorm2d(out_ch))
    elif norm_type == 'instancenorm':
        layers.append(InstanceNormalization(num_features=out_ch))
    layers.append(nn.ReLU(inplace=True))
    if apply_dropout:
        layers.append(nn.Dropout(0.5))
    return nn.Sequential(*layers)


class UNetGenerator(nn.Module):
    """pix2pix-style U-Net generator that runs at a fixed internal resolution.

    The input is bilinearly resized to ``target_size`` x ``target_size``. It then
    goes through eight stride-2 encoder stages (``down1``..``down8``) and seven
    decoder stages (``up1``..``up7``), each decoder output concatenated with the
    mirrored encoder output. A final transposed conv + tanh maps to
    ``output_channels``, and the result is resized back to the input's original
    (H, W).

    NOTE(review): with target_size=256 the eight stride-2 stages shrink the
    bottleneck (down8) to 1x1. With norm_type='instancenorm', normalizing a single
    value per channel yields 0, so down8's output is just its learned offset and
    carries no information about the input. Confirm this is intended.
    """

    def __init__(self, input_channels=3, output_channels=3, norm_type='instancenorm', target_size=256):
        super().__init__()
        self.target_size = target_size

        # (in_channels, out_channels, apply_norm) for down1..down8, in registration order.
        encoder_spec = [
            (input_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
        ]
        for stage, (c_in, c_out, with_norm) in enumerate(encoder_spec, start=1):
            setattr(self, f"down{stage}", downsample(c_in, c_out, norm_type, apply_norm=with_norm))

        # (in_channels, out_channels, apply_dropout) for up1..up7; in_channels
        # after up1 are doubled by the skip concatenation.
        decoder_spec = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        for stage, (c_in, c_out, with_dropout) in enumerate(decoder_spec, start=1):
            setattr(self, f"up{stage}", upsample(c_in, c_out, norm_type, apply_dropout=with_dropout))

        self.final = nn.ConvTranspose2d(128, output_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        input_hw = x.shape[-2:]
        h = F.interpolate(x, size=(self.target_size, self.target_size), mode='bilinear', align_corners=False)

        encoders = (self.down1, self.down2, self.down3, self.down4,
                    self.down5, self.down6, self.down7, self.down8)
        features = []
        for encode in encoders:
            h = encode(h)
            features.append(h)

        # h is now the bottleneck (down8 output); the skips are down7..down1.
        decoders = (self.up1, self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for decode, skip in zip(decoders, reversed(features[:-1])):
            h = torch.cat([decode(h), skip], dim=1)

        out = self.tanh(self.final(h))
        return F.interpolate(out, size=input_hw, mode='bilinear', align_corners=False)

# === PatchGAN Discriminator ===
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator that outputs a spatial grid of scores.

    Args:
        input_channels: channels of one image.
        norm_type: passed to the two ``downsample`` stages. The 512-channel stage
            uses ``BatchNorm2d`` when it is ``'batchnorm'`` and
            ``InstanceNormalization`` otherwise.
        target: when True the network expects two images concatenated along the
            channel axis (``2 * input_channels`` channels); otherwise one image.
    """

    def __init__(self, input_channels=3, norm_type='batchnorm', target=True):
        super().__init__()
        self.target = target
        channels_in = (2 if target else 1) * input_channels

        # First stage has no normalization.
        layers = [
            nn.Conv2d(channels_in, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        layers += [downsample(64, 128, norm_type), downsample(128, 256, norm_type)]
        layers += [
            nn.ZeroPad2d(1),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512) if norm_type == 'batchnorm' else InstanceNormalization(),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Single-channel score map.
        layers += [nn.ZeroPad2d(1), nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0)]
        self.model = nn.Sequential(*layers)

    def forward(self, inp, target=None):
        """Score ``inp`` (concatenated with ``target`` on dim 1 when conditional).

        If the module was built with ``target=True`` but ``target`` is None, ``inp``
        is passed alone and will not match the expected ``2 * input_channels``.
        """
        x = inp
        if self.target and target is not None:
            x = torch.cat([inp, target], dim=1)
        return self.model(x)


In [None]:


def asmape(y_true, y_pred, mask=None):
    """Return the 'asmape' error between ``y_true`` and ``y_pred``, in percent.

    Each scored element contributes ``|y_pred - y_true| / (|y_true| + |y_pred|)``.
    The ratios are summed, divided by the number of scored elements and
    multiplied by 100, so the result lies in [0, 100]. An element where both
    values are 0 gives 0/0 = NaN. ``np.nansum`` skips it, but it still counts in
    the denominator, i.e. it is treated as a perfect prediction.

    Args:
        y_true, y_pred: array-likes of equal shape (lists, NumPy arrays, CPU
            tensors without grad).
        mask: optional array-like of the same shape; only elements with
            ``mask == 1`` are scored.

    Returns:
        The error as a float percentage, or NaN when no element is scored.
    """
    # Convert first so masking also works for plain lists.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if mask is not None:
        keep = np.asarray(mask) == 1
        y_true, y_pred = y_true[keep], y_pred[keep]

    # Use the element count (not len(), which is only the first axis) so
    # multi-dimensional inputs are averaged correctly.
    count = y_true.size
    if count == 0:
        return float('nan')

    # 0/0 for elements where both values are zero is expected; nansum drops it.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratios = np.abs(y_pred - y_true) / (np.abs(y_true) + np.abs(y_pred))
    return 100 * (np.nansum(ratios) / count)

class LoaderDataset(Dataset):
    """Dataset of ``.npy`` triples: label image ("zebra"), input image ("horse"), mask.

    Each directory is listed once and sorted by file name. Item ``i`` uses the
    ``i``-th file of each directory, wrapping around with modulo when a directory
    has fewer files. The dataset length is the larger of the zebra and horse file
    counts.

    Each item is ``(zebra, horse, zebra_min, zebra_max, mask)``:
    * ``zebra``/``horse``: float32 tensors (C, H, W), min-max scaled to [-1, 1].
      ``chanels`` selects all channels (3), the first two (2) or their sum (1).
    * ``zebra_min``/``zebra_max``: the label's range before scaling.
    * ``mask``: the raw mask array as a float32 tensor.
    """

    def __init__(self, root_zebra, root_horse, root_masks, chanels=3):
        # Note: the mask directory is stored as ``root_index``.
        self.root_zebra, self.root_horse, self.root_index = root_zebra, root_horse, root_masks

        self.zebra_images = sorted(os.listdir(root_zebra))
        self.horse_images = sorted(os.listdir(root_horse))
        self.index = sorted(os.listdir(root_masks))

        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        self.index_len = len(self.index)
        self.length_dataset = max(self.zebra_len, self.horse_len)
        self.chanels = chanels

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def custom_normalize(image):
        """Min-max scale ``image`` to [-1, 1]; return ``(scaled, min, max)``.

        The range is clamped to at least 1e-5 so constant images do not divide by zero.
        """
        tensor = torch.tensor(image, dtype=torch.float32)
        lowest, highest = torch.min(tensor), torch.max(tensor)
        span = torch.clamp(highest - lowest, min=1e-5)
        return 2 * (tensor - lowest) / span - 1, lowest, highest

    def __getitem__(self, index):
        zebra_path = os.path.join(self.root_zebra, self.zebra_images[index % self.zebra_len])
        horse_path = os.path.join(self.root_horse, self.horse_images[index % self.horse_len])
        mask_path = os.path.join(self.root_index, self.index[index % self.index_len])

        zebra = np.load(zebra_path)
        horse = np.load(horse_path)
        mask = np.load(mask_path)

        # Arrays stored with extra leading axes are flattened back to (32, 32, 3);
        # only the zebra array's rank is inspected.
        if zebra.ndim > 3:
            zebra, horse = zebra.reshape(32, 32, 3), horse.reshape(32, 32, 3)

        # HWC -> CHW
        zebra = np.transpose(zebra, (2, 0, 1))
        horse = np.transpose(horse, (2, 0, 1))

        if self.chanels == 2:
            zebra, horse = zebra[:2], horse[:2]
        elif self.chanels == 1:
            zebra = zebra.sum(axis=0, keepdims=True)
            horse = horse.sum(axis=0, keepdims=True)

        zebra_t, zebra_min, zebra_max = LoaderDataset.custom_normalize(zebra)
        horse_t = LoaderDataset.custom_normalize(horse)[0]
        mask_t = torch.tensor(mask, dtype=torch.float32)

        return zebra_t, horse_t, zebra_min, zebra_max, mask_t

import pandas as pd 
from sklearn.metrics import mean_absolute_percentage_error as mape, mean_absolute_error as mae, mean_squared_error as mse

import sys
import time
import datetime


def add_masked_gaussian_noise(x: torch.Tensor, mask: torch.Tensor, sigma: float) -> torch.Tensor:
    """Return ``x`` plus N(0, sigma^2) noise applied only where ``mask == 0``.

    Args:
        x: tensor of shape (B, C, H, W).
        mask: tensor of shape (B, 1, H, W). Its single channel is expanded to C.
            Positions with value 1 are left untouched; positions with value 0
            receive the full noise.
        sigma: noise standard deviation. Noise is still drawn when it is 0, so
            the RNG stream advances identically.
    """
    outside_mask = (1.0 - mask).expand(-1, x.size(1), -1, -1)
    perturbation = torch.randn_like(x) * sigma
    return x + outside_mask * perturbation



def train_fn(dataloader, G_AB, G_BA, D_A, D_B, opt_G, opt_D_A, opt_D_B,
             pixelwise_loss, cycle_loss, adversarial_loss, val_loader, scaler,
             patience, save_dir, NUM_EPOCHS=1000):
    """Train both generators and both conditional discriminators with early stopping.

    Each epoch:
      1. Training. Gaussian noise is added to pixels where ``mask == 0``; sigma
         decays linearly from 0.1 to 0 over the first 100 epochs.
         - The generators are updated with adversarial + cycle + pixelwise losses,
           all computed on the ``mask == 1`` region.
         - Then each discriminator is updated on a real pair and a detached
           generated pair.
         - Discriminators are called as ``D(candidate, condition)``. The condition
           is always the source image, so a real pair is never ``(x, x)``.
      2. Validation. Masked pixelwise loss of both generators on clean
         (noise-free) inputs, so the metric is comparable across epochs.
      3. Checkpointing. When validation improves, ``G_AB``/``G_BA`` weights are
         written atomically to ``save_dir``. After ``patience`` consecutive
         epochs without improvement, training stops.

    When training ends, the best saved generator weights are loaded back into
    ``G_AB``/``G_BA``, so callers evaluate the best model rather than the last one.

    Args:
        dataloader, val_loader: yield ``(real_A, real_B, _, _, masks)``.
            ``masks`` must hold H*W values per sample.
        G_AB, G_BA: generators A->B and B->A.
        D_A, D_B: discriminators for domains A and B.
        opt_G, opt_D_A, opt_D_B: optimizers (``opt_G`` covers both generators).
        pixelwise_loss, cycle_loss, adversarial_loss: loss criteria.
        scaler: a ``torch.amp.GradScaler``, or None for plain backward/step.
        patience: epochs without validation improvement before stopping.
        save_dir: directory for ``G_AB.pth``/``G_BA.pth`` (created if missing).
        NUM_EPOCHS: maximum number of epochs.

    Returns:
        The best (lowest) average validation loss observed.

    Raises:
        ValueError: if ``val_loader`` has no batches.
    """
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty; validation-based early stopping needs at least one batch")

    def safe_save(model, path):
        # Write to a temp file, then atomically rename, so an interrupted save
        # never leaves a truncated checkpoint at `path`.
        tmp_path = path + ".tmp"
        torch.save(model.state_dict(), tmp_path)
        os.replace(tmp_path, path)

    def james_stein_reduce(errors: torch.Tensor) -> torch.Tensor:
        # Shrinkage-style reduction. For a single-element input (a scalar loss,
        # e.g. from criteria with reduction='mean'), var and (errors - mean) are
        # zero, so this returns the input unchanged.
        mean = torch.mean(errors)
        var = torch.var(errors, unbiased=False)
        norm_sq = torch.sum((errors - mean) ** 2)
        dim = errors.numel()
        shrinkage = torch.clamp(1 - ((dim - 2) * var / (norm_sq + 1e-8)), min=0.0)
        return mean + shrinkage * (errors - mean).mean()

    def backward_and_step(loss, optimizer):
        # Works both with a GradScaler and without one (scaler=None).
        if scaler is None:
            loss.backward()
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            scaler.step(optimizer)

    def to_device(real_A, real_B, masks):
        real_A = real_A.to(device)
        real_B = real_B.to(device)
        masks = masks.to(device).view(-1, 1, *real_A.shape[-2:]).float()
        return real_A, real_B, masks

    # Use the models' actual device rather than assuming CUDA.
    device = next(G_AB.parameters()).device
    use_amp = device.type == "cuda"
    initial_sigma = 0.1
    warmup_epochs = 100

    os.makedirs(save_dir, exist_ok=True)
    model_path_ab = os.path.join(save_dir, "G_AB.pth")
    model_path_ba = os.path.join(save_dir, "G_BA.pth")

    best_loss = float('inf')
    epochs_no_improve = 0
    saved_best = False
    prev_time = time.time()

    for epoch in range(NUM_EPOCHS):
        # Validation below switches the generators to eval(); switch back every
        # epoch so dropout stays active during training.
        for net in (G_AB, G_BA, D_A, D_B):
            net.train()
        sigma = max(0.0, initial_sigma * (1 - epoch / warmup_epochs))

        for i, (real_A, real_B, _, _, masks) in enumerate(dataloader):
            real_A, real_B, masks = to_device(real_A, real_B, masks)
            # Add Gaussian noise to the mask == 0 regions (training-time augmentation).
            real_A = add_masked_gaussian_noise(real_A, masks, sigma=sigma)
            real_B = add_masked_gaussian_noise(real_B, masks, sigma=sigma)
            masked_A = real_A * masks
            masked_B = real_B * masks

            # ---- Generators ----
            opt_G.zero_grad()
            with torch.amp.autocast('cuda', enabled=use_amp):
                fake_B = G_AB(real_A)
                fake_A = G_BA(real_B)

                # A generated image is judged together with the image it was
                # translated from (conditional GAN).
                pred_fake_B = D_B(fake_B * masks, masked_A)
                pred_fake_A = D_A(fake_A * masks, masked_B)
                valid_B = torch.ones_like(pred_fake_B, dtype=torch.float32)
                valid_A = torch.ones_like(pred_fake_A, dtype=torch.float32)

                loss_GAN_AB = adversarial_loss(pred_fake_B, valid_B)
                loss_GAN_BA = adversarial_loss(pred_fake_A, valid_A)

                recov_A = G_BA(fake_B)
                recov_B = G_AB(fake_A)

                loss_cycle_A = cycle_loss(recov_A * masks, masked_A)
                loss_cycle_B = cycle_loss(recov_B * masks, masked_B)

                loss_pixelwise = (pixelwise_loss(fake_A * masks, masked_A) +
                                  pixelwise_loss(fake_B * masks, masked_B)) / 2

                loss_G_raw = loss_GAN_AB + loss_GAN_BA + loss_cycle_A + loss_cycle_B + loss_pixelwise
                loss_G = james_stein_reduce(loss_G_raw.flatten())
            backward_and_step(loss_G, opt_G)

            # ---- Discriminator A: (real_A | real_B) vs (fake_A | real_B) ----
            opt_D_A.zero_grad()
            with torch.amp.autocast('cuda', enabled=use_amp):
                pred_real = D_A(masked_A, masked_B)
                pred_fake = D_A(fake_A.detach() * masks, masked_B)
                loss_D_A = (adversarial_loss(pred_real, valid_A) +
                            adversarial_loss(pred_fake, torch.zeros_like(valid_A))) / 2
            backward_and_step(loss_D_A, opt_D_A)

            # ---- Discriminator B: (real_B | real_A) vs (fake_B | real_A) ----
            opt_D_B.zero_grad()
            with torch.amp.autocast('cuda', enabled=use_amp):
                pred_real = D_B(masked_B, masked_A)
                pred_fake = D_B(fake_B.detach() * masks, masked_A)
                loss_D_B = (adversarial_loss(pred_real, valid_B) +
                            adversarial_loss(pred_fake, torch.zeros_like(valid_B))) / 2
            backward_and_step(loss_D_B, opt_D_B)

            # One scale update per iteration, after all optimizers have stepped.
            if scaler is not None:
                scaler.update()

            batches_done = epoch * len(dataloader) + i + 1
            batches_left = NUM_EPOCHS * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=int(batches_left * (time.time() - prev_time)))
            prev_time = time.time()

            sys.stdout.write(
                f"\r[Epoch {epoch+1}/{NUM_EPOCHS}] [Batch {i+1}/{len(dataloader)}] "
                f"[D loss: {(loss_D_A.item() + loss_D_B.item())/2:.6f}] "
                f"[G loss: {loss_G.item():.6f}, adv: {(loss_GAN_AB.item() + loss_GAN_BA.item())/2:.6f}, "
                f"pixel: {loss_pixelwise.item():.6f}, cycle: {(loss_cycle_A.item() + loss_cycle_B.item())/2:.6f}] "
                f"ETA: {time_left}   "
            )
            sys.stdout.flush()

        # ---- Validation (clean inputs: noise is a training-only augmentation) ----
        G_AB.eval()
        G_BA.eval()
        val_loss = 0.0
        with torch.no_grad():
            for real_A, real_B, _, _, masks in val_loader:
                real_A, real_B, masks = to_device(real_A, real_B, masks)
                fake_A = G_BA(real_B)
                fake_B = G_AB(real_A)
                loss_val = (pixelwise_loss(fake_A * masks, real_A * masks) +
                            pixelwise_loss(fake_B * masks, real_B * masks)) / 2
                val_loss += loss_val.item()
        avg_val_loss = val_loss / len(val_loader)

        # ---- Checkpoint the best generators / early stopping ----
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            epochs_no_improve = 0
            safe_save(G_AB, model_path_ab)
            safe_save(G_BA, model_path_ba)
            saved_best = True
        else:
            # Equal, worse or NaN losses all count as "no improvement".
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"\nEarly stopping at epoch {epoch+1} with best loss {best_loss:.6f}")
                break

    # Hand back the best generators, not the last-epoch ones.
    if saved_best:
        G_AB.load_state_dict(torch.load(model_path_ab, map_location=device, weights_only=True))
        G_BA.load_state_dict(torch.load(model_path_ba, map_location=device, weights_only=True))
    return best_loss






def test(gen_Z, test_loader, taxa, fold, chanells, DEVICE):
    """Evaluate ``gen_Z`` on ``test_loader`` and write per-sample metrics to CSV.

    ``test_loader`` must yield ``(label, input, label_min, label_max, mask)``,
    the tuple returned by ``LoaderDataset``. It must use batch_size=1 and no
    shuffling so that batches line up with ``test_loader.dataset.horse_images``
    (the CSV row names).

    For each sample:
      1. Generate a prediction from the input.
      2. Undo ``LoaderDataset.custom_normalize`` on the label and the prediction,
         using the label's min/max.
      3. Sum over channels and keep only pixels where ``mask == 1``.
      4. Record MAE, asmape, MAPE (%), RMSE and the range of the true values,
         each rounded to 3 decimals (range unrounded).

    A sample with no ``mask == 1`` pixel gets a row of NaN. Results are written to
    ``./resultados/resultados_discogan/result_{chanells}c_{taxa}_{fold}.csv``.
    """
    columns = ['mae', 'asmape', 'mape', 'rmse', 'scale']
    names = test_loader.dataset.horse_images
    df = pd.DataFrame([], columns=columns, index=names)

    gen_Z.eval()
    with torch.no_grad():
        for (zebra, horse, min_val, max_val, masks), name in zip(test_loader, names):
            fake_zebra = gen_Z(horse.to(DEVICE)).float().cpu()
            zebra = zebra.float().cpu()

            # The loader provides the label's (min, max), not (std, mean).
            # Invert  x_n = 2 * (x - min) / clamp(max - min, 1e-5) - 1.
            min_val = torch.as_tensor(min_val, dtype=torch.float32).reshape(-1, 1, 1, 1)
            max_val = torch.as_tensor(max_val, dtype=torch.float32).reshape(-1, 1, 1, 1)
            span = torch.clamp(max_val - min_val, min=1e-5)
            zebra = (zebra + 1) / 2 * span + min_val
            fake_zebra = (fake_zebra + 1) / 2 * span + min_val

            # Sum over channels, flatten, then score only the mask == 1 pixels
            # (the same pixels the training losses use).
            y_true = torch.sum(zebra, dim=1).reshape(-1).numpy()
            y_pred = torch.sum(fake_zebra, dim=1).reshape(-1).numpy()
            keep = masks.reshape(-1).cpu().numpy() == 1
            if keep.shape != y_true.shape:
                raise ValueError(
                    f"mask has {keep.size} values but the image has {y_true.size} pixels; "
                    "test() expects batch_size=1 and one mask value per pixel")
            y_true, y_pred = y_true[keep], y_pred[keep]

            if y_true.size == 0:
                df.loc[name] = [np.nan] * len(columns)
                continue

            mae_value = round(mae(y_true, y_pred), 3)
            mape_value = round(mape(y_true, y_pred) * 100, 3)
            rmse_value = round(np.sqrt(mse(y_true, y_pred)), 3)
            smape_value = round(asmape(y_true, y_pred), 3)

            df.loc[name] = [mae_value, smape_value, mape_value, rmse_value, np.max(y_true) - np.min(y_true)]

    directory = "./resultados/resultados_discogan"
    os.makedirs(directory, exist_ok=True)
    df.to_csv(os.path.join(directory, f'result_{str(chanells)}c_{taxa}_{fold}.csv'))




# ----------
#  Treinamento
# ----------
# Parâmetros de treinamento

class EarlyStopping:
    """Track a validation loss and flag when it stops improving.

    A call with a loss strictly lower than the best so far (or the first call)
    records it and resets the counter. Any other call increments the counter, and
    once it reaches ``patience``, ``early_stop`` becomes True. ``early_stop`` is
    never reset afterwards.

    Attributes:
        best_loss: lowest loss seen (None before the first call).
        best_epoch: epoch passed with the best loss (0 if none was given).
        counter: consecutive calls without improvement.
        early_stop: True once ``counter >= patience``.
    """

    def __init__(self, patience=100, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.best_epoch = 0
        self.early_stop = False

    def __call__(self, val_loss, epoch=None):
        improved = self.best_loss is None or val_loss < self.best_loss
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            if epoch is not None:
                self.best_epoch = epoch
            return

        self.counter += 1
        if self.verbose:
            message = (
                f"Epoch {epoch}: EarlyStopping counter = {self.counter}/{self.patience}"
                if epoch is not None
                else f"EarlyStopping: {self.counter}/{self.patience}"
            )
            print(message)
        if self.counter >= self.patience:
            self.early_stop = True



def main(in_channels, taxas=('10', '20', '30', '40'), folds=('1', '2', '3', '4', '5')):
    """Run the training/evaluation sweep for one channel configuration.

    For every ``taxa`` sub-folder and fold under ``../dataset_final``:
      1. Build fresh generators, discriminators, optimizers, GradScaler and data
         loaders.
      2. Train with ``train_fn`` (early-stopping patience 50).
      3. Evaluate ``G_BA`` with ``test``.

    Checkpoints go to ``./models_saved/discogan/{in_channels}/{taxa}/fold{fold}``.

    Args:
        in_channels: image channels used by the models and ``LoaderDataset``.
        taxas: taxa sub-folders to run (default: the four used originally).
        folds: fold numbers to run (default: 1-5).
    """
    TRAIN_DIR = os.path.abspath("../dataset_final")
    VAL_DIR = os.path.abspath("../dataset_final")
    INDEX_TRAIN = os.path.abspath("../dataset_final")
    INDEX_VAL = os.path.abspath("../dataset_final")
    INDEX_TEST = os.path.abspath("../dataset_final")
    BATCH_SIZE = 16
    NUM_EPOCHS = 1000
    PATIENCE = 50

    lrd = 1e-4  # both discriminators
    # NOTE(review): an earlier version declared lrg = 1e-3 but hard-coded the
    # generator optimizer to 1e-4; the effective 1e-4 is kept — confirm intent.
    lrg = 1e-4
    b1 = 0.5
    b2 = 0.999

    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"

    def make_dataset(data_root, split, index_root, index_split, taxa, fold):
        fold_dir = ("folds", f"fold{fold}")
        return LoaderDataset(
            root_zebra=os.path.join(data_root, "label", taxa, *fold_dir, split),
            root_horse=os.path.join(data_root, "input", taxa, *fold_dir, split),
            root_masks=os.path.join(index_root, "input", taxa, *fold_dir, index_split),
            chanels=in_channels,
        )

    for taxa in taxas:
        for fold in folds:
            print(f"Treinando taxa={taxa}, fold={fold}")
            save_dir = f"./models_saved/discogan/{in_channels}/{taxa}/fold{fold}"

            # Losses
            adversarial_loss = torch.nn.MSELoss().to(device)
            cycle_loss = torch.nn.L1Loss().to(device)
            pixelwise_loss = torch.nn.L1Loss().to(device)

            # Models
            G_AB = UNetGenerator(input_channels=in_channels, output_channels=in_channels).to(device)
            G_BA = UNetGenerator(input_channels=in_channels, output_channels=in_channels).to(device)
            D_A = PatchDiscriminator(input_channels=in_channels).to(device)
            D_B = PatchDiscriminator(input_channels=in_channels).to(device)

            for net in (G_AB, G_BA, D_A, D_B):
                net.apply(weights_init_normal)

            # Optimizers (both discriminators now share the same learning rate).
            optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lrg, betas=(b1, b2))
            optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lrd, betas=(b1, b2))
            optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lrd, betas=(b1, b2))

            # Datasets and loaders
            train_dataset = make_dataset(TRAIN_DIR, "train", INDEX_TRAIN, "index_train", taxa, fold)
            val_dataset = make_dataset(VAL_DIR, "val", INDEX_VAL, "index_val", taxa, fold)
            test_dataset = make_dataset(VAL_DIR, "test", INDEX_TEST, "index", taxa, fold)

            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=use_cuda)
            val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=use_cuda)
            # test() requires batch_size=1 and a fixed order.
            test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=use_cuda)

            # Fresh scaler per run; it is a no-op when CUDA is unavailable.
            scaler = torch.amp.GradScaler('cuda', enabled=use_cuda)

            train_fn(
                train_loader, G_AB, G_BA, D_A, D_B,
                optimizer_G, optimizer_D_A, optimizer_D_B,
                pixelwise_loss, cycle_loss, adversarial_loss,
                val_loader, scaler,
                PATIENCE,
                save_dir=save_dir,
                NUM_EPOCHS=NUM_EPOCHS,
            )

            test(G_BA, test_loader, taxa, fold, in_channels, device)

import gc
if __name__ == '__main__':
    for i in [1, 2, 3]:
        print(f"\n\n\nIniciando treinamento com {i} canais\n\n\n")
        main(i)
        # Free memory between runs. First collect unreachable Python objects so
        # the tensors they hold are released, then return PyTorch's cached CUDA
        # blocks.
        gc.collect()
        torch.cuda.empty_cache()
        

        




Iniciando treinamento com 1 canais



Treinando taxa=10, fold=1
[Epoch 102/1000] [Batch 65/65] [D loss: 0.002819] [G loss: 2.180770, adv: 1.057216, pixel: 0.024692, cycle: 0.020823] ETA: 0:37:26    
Early stopping at epoch 102 with best loss 0.016918
Treinando taxa=10, fold=2
[Epoch 109/1000] [Batch 65/65] [D loss: 0.002413] [G loss: 2.164432, adv: 1.041265, pixel: 0.029739, cycle: 0.026081] ETA: 0:38:39     
Early stopping at epoch 109 with best loss 0.015440
Treinando taxa=10, fold=3
[Epoch 79/1000] [Batch 65/65] [D loss: 0.037119] [G loss: 1.332749, adv: 0.639435, pixel: 0.016326, cycle: 0.018777] ETA: 0:39:55         
Early stopping at epoch 79 with best loss 0.017438
Treinando taxa=10, fold=4
[Epoch 105/1000] [Batch 65/65] [D loss: 0.033095] [G loss: 2.204990, adv: 1.078011, pixel: 0.015839, cycle: 0.016564] ETA: 0:38:54        
Early stopping at epoch 105 with best loss 0.016711
Treinando taxa=10, fold=5
[Epoch 75/1000] [Batch 65/65] [D loss: 0.003379] [G loss: 2.174751, adv: 

: 