In [None]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import mean_absolute_percentage_error as mape, mean_absolute_error as mae, mean_squared_error as mse
import pandas as pd
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.optim as optim
import os
import sys
import torch.nn.functional as F


def asmape(y_true, y_pred, mask=None):
    """Return the symmetric mean absolute percentage error, in percent.

    Computes ``100 / n * sum(|y_pred - y_true| / (|y_true| + |y_pred|))``.
    Elements where both values are 0 give 0/0 = NaN; ``np.nansum`` skips them,
    but they still count in ``n`` (i.e. they are treated as zero error).

    Args:
        y_true: ground-truth values (list, NumPy array or CPU tensor).
        y_pred: predicted values, same shape as ``y_true``.
        mask: optional array-like of the same shape; only elements where
            ``mask == 1`` are scored.

    Returns:
        The score as a NumPy float, or NaN when no elements are scored.
    """
    # Convert first so masking works for lists too (boolean indexing of a list raises).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if mask is not None:
        keep = np.asarray(mask) == 1
        y_true, y_pred = y_true[keep], y_pred[keep]

    # Count every element (not len(), which is only the first dimension of N-D inputs).
    n = y_true.size
    if n == 0:
        return float('nan')

    # 0/0 is expected for all-zero pairs and handled by nansum; silence the warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = np.abs(y_pred - y_true) / (np.abs(y_true) + np.abs(y_pred))
    return 100 * (np.nansum(ratio) / n)


class LoaderDataset(Dataset):
    """Paired ``.npy`` image dataset: "horse" images are inputs, "zebra" images targets.

    Each directory is listed and sorted by file name; item ``i`` takes the
    ``i``-th file (modulo the directory size) from the zebra, horse and mask
    directories. ``len()`` is the larger of the zebra/horse file counts.

    Args:
        root_zebra: directory of target images.
        root_horse: directory of input images.
        root_masks: directory of per-sample masks.
        chanels: 3 keeps all channels, 2 keeps the first two, 1 sums the
            channels into one.
    """

    def __init__(self, root_zebra, root_horse, root_masks, chanels=3):
        self.root_zebra, self.root_horse, self.root_index = root_zebra, root_horse, root_masks

        # Sorted listings make the name-order pairing deterministic.
        self.zebra_images = sorted(os.listdir(root_zebra))
        self.horse_images = sorted(os.listdir(root_horse))
        self.index = sorted(os.listdir(root_masks))

        self.zebra_len, self.horse_len, self.index_len = (
            len(self.zebra_images),
            len(self.horse_images),
            len(self.index),
        )
        self.length_dataset = max(self.zebra_len, self.horse_len)
        self.chanels = chanels

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def custom_normalize(image):
        """Min-max scale ``image`` to [-1, 1]; return ``(scaled, min, max)`` as tensors."""
        data = torch.tensor(image, dtype=torch.float32)
        lo, hi = torch.min(data), torch.max(data)
        span = (hi - lo).clamp(min=1e-5)  # avoids division by zero for constant images
        return 2 * (data - lo) / span - 1, lo, hi

    def __getitem__(self, index):
        """Return ``(zebra, horse, zebra_min, zebra_max, mask)`` for ``index``.

        Images are returned channel-first and scaled to [-1, 1]; only the
        zebra's original min/max are returned (the horse's are discarded).
        """
        def _load(root, names, count):
            return np.load(os.path.join(root, names[index % count]))

        target = _load(self.root_zebra, self.zebra_images, self.zebra_len)
        source = _load(self.root_horse, self.horse_images, self.horse_len)
        mask = _load(self.root_index, self.index, self.index_len)

        # Arrays with more than 3 dims are reshaped into one 32x32x3 image.
        if target.ndim > 3:
            target = target.reshape(32, 32, 3)
            source = source.reshape(32, 32, 3)

        # HWC -> CHW
        target = target.transpose(2, 0, 1)
        source = source.transpose(2, 0, 1)

        if self.chanels == 2:
            target, source = target[:2], source[:2]
        elif self.chanels == 1:
            target = target.sum(axis=0, keepdims=True)
            source = source.sum(axis=0, keepdims=True)

        target, target_min, target_max = LoaderDataset.custom_normalize(target)
        source = LoaderDataset.custom_normalize(source)[0]

        return target, source, target_min, target_max, torch.tensor(mask, dtype=torch.float32)
    



# class ToFloat32:
#     def __call__(self, image, **kwargs):
#         return image.float()




# === Custom instance normalization (TensorFlow-style, learnable scale/offset) ===
class InstanceNormalization(nn.Module):
    """Instance normalization with a learnable per-channel affine transform.

    Each (sample, channel) plane of an (N, C, H, W) input is normalized over
    H and W with the biased variance, then ``scale * x_hat + offset`` is applied.

    Args:
        epsilon: constant added to the variance for numerical stability.
        num_features: number of channels C. When given, ``scale`` and
            ``offset`` are registered here, so they are seen by an optimizer
            built right after construction, appear in ``state_dict()``, and are
            moved by ``.to(device)``. When omitted, they are created on the first
            forward call (legacy behaviour); such parameters are NOT part of an
            optimizer built before that call and, under ``nn.DataParallel``,
            are created on the per-call replicas and never persisted.
    """

    def __init__(self, epsilon=1e-5, num_features=None):
        super().__init__()
        self.epsilon = epsilon
        if num_features is not None:
            self.scale = nn.Parameter(torch.ones(num_features))
            self.offset = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # x shape: (N, C, H, W)
        mean = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True, unbiased=False)
        inv = 1.0 / torch.sqrt(var + self.epsilon)
        normalized = (x - mean) * inv

        if not hasattr(self, 'scale'):
            # Legacy lazy path (num_features not given): see class docstring caveats.
            self.scale = nn.Parameter(torch.ones(x.size(1), device=x.device))
            self.offset = nn.Parameter(torch.zeros(x.size(1), device=x.device))
        # Reshape to (1, C, 1, 1) to broadcast over batch and spatial dims.
        scale = self.scale.view(1, -1, 1, 1)
        offset = self.offset.view(1, -1, 1, 1)
        return scale * normalized + offset


# === Downsample and upsample stages ===
def downsample(in_ch, out_ch, norm_type='batchnorm', apply_norm=True):
    """Conv(4x4, stride 2) -> optional norm -> LeakyReLU(0.2); halves H and W.

    ``norm_type`` is 'batchnorm' or 'instancenorm'; any other value (or
    ``apply_norm=False``) adds no normalization layer.
    """
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)]
    if apply_norm:
        if norm_type == 'batchnorm':
            layers.append(nn.BatchNorm2d(out_ch))
        elif norm_type == 'instancenorm':
            # Pass the channel count so the affine parameters exist before the optimizer is built.
            layers.append(InstanceNormalization(num_features=out_ch))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)


def upsample(in_ch, out_ch, norm_type='batchnorm', apply_dropout=False):
    """ConvTranspose(4x4, stride 2) -> optional norm -> ReLU -> optional Dropout(0.5); doubles H and W.

    ``norm_type`` is 'batchnorm' or 'instancenorm'; any other value adds no
    normalization layer.
    """
    layers = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)]
    if norm_type == 'batchnorm':
        layers.append(nn.BatchNorm2d(out_ch))
    elif norm_type == 'instancenorm':
        layers.append(InstanceNormalization(num_features=out_ch))
    layers.append(nn.ReLU(inplace=True))
    if apply_dropout:
        layers.append(nn.Dropout(0.5))
    return nn.Sequential(*layers)


class UNetGenerator(nn.Module):
    """Pix2Pix U-Net generator: 8 encoder stages, 7 decoder stages with skip connections.

    The input is bilinearly resized to ``target_size x target_size`` before the
    encoder, and the tanh output is resized back to the input's (H, W).
    ``target_size`` should be a multiple of 256 so that eight stride-2 halvings
    reach at least 1x1 and decoder outputs line up with their skip tensors.
    NOTE: for small inputs (e.g. 32x32) the resize multiplies activation memory
    by ``(target_size / H) ** 2``.
    """

    def __init__(self, input_channels=3, output_channels=3, norm_type='batchnorm', target_size=256):
        super().__init__()
        self.target_size = target_size

        # (in, out, apply_norm) for down1..down8; the outermost stage is not normalized.
        encoder_specs = [
            (input_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
        ]
        for idx, (c_in, c_out, use_norm) in enumerate(encoder_specs, start=1):
            setattr(self, f"down{idx}", downsample(c_in, c_out, norm_type, apply_norm=use_norm))

        # (in, out, apply_dropout) for up1..up7; inputs after up1 are doubled by the skip concat.
        decoder_specs = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        for idx, (c_in, c_out, use_dropout) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{idx}", upsample(c_in, c_out, norm_type, apply_dropout=use_dropout))

        self.final = nn.ConvTranspose2d(128, output_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        out_hw = x.shape[-2:]
        h = F.interpolate(x, size=(self.target_size, self.target_size), mode='bilinear', align_corners=False)

        skips = []
        for stage in (self.down1, self.down2, self.down3, self.down4,
                      self.down5, self.down6, self.down7, self.down8):
            h = stage(h)
            skips.append(h)

        # skips[-1] is the bottleneck (already in h); pair up1..up7 with d7..d1.
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for stage, skip in zip(decoder, reversed(skips[:-1])):
            h = torch.cat([stage(h), skip], dim=1)

        h = self.tanh(self.final(h))
        # Resize back to the caller's spatial size.
        return F.interpolate(h, size=out_hw, mode='bilinear', align_corners=False)

# === PatchGAN discriminator ===
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator that outputs a spatial grid of real/fake logits.

    With ``target=True`` (default) the first conv expects ``2 * input_channels``
    channels. Either pass the already-concatenated (input, image) pair as
    ``inp``, or pass ``inp`` and ``target`` separately to have them
    concatenated on dim 1.
    """

    def __init__(self, input_channels=3, norm_type='batchnorm', target=True):
        super().__init__()
        self.target = target
        first_in = 2 * input_channels if target else input_channels

        layers = [
            nn.Conv2d(first_in, 64, kernel_size=4, stride=2, padding=1),  # no normalization
            nn.LeakyReLU(0.2, inplace=True),
            downsample(64, 128, norm_type),
            downsample(128, 256, norm_type),
            nn.ZeroPad2d(1),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0, bias=False),
            # Any norm_type other than 'batchnorm' uses InstanceNormalization here.
            nn.BatchNorm2d(512) if norm_type == 'batchnorm' else InstanceNormalization(),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ZeroPad2d(1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, inp, target=None):
        concat_pair = self.target and target is not None
        x = torch.cat([inp, target], dim=1) if concat_pair else inp
        return self.model(x)


# James-Stein-style reduction (mathematically the mean; see docstring)
def james_stein_reduce(errors: torch.Tensor) -> torch.Tensor:
    """Reduce a tensor of per-element errors to a scalar: their mean.

    The earlier form ``mean + shrinkage * (errors - mean).mean()`` multiplied
    the shrinkage factor by ``(errors - mean).mean()``, which is identically
    zero, so it always equalled ``errors.mean()`` in value and gradient (up to
    float rounding). The variance/shrinkage computation did no work. Note that
    James-Stein shrinkage toward the grand mean followed by averaging never
    changes the mean. A real shrinkage effect would require using the shrunk
    per-element errors before reduction, or a different shrinkage target.

    Returns:
        0-dim tensor; NaN for an empty input (as ``torch.mean``).
    """
    return errors.mean()

# Discriminator loss
def discriminator_loss(disc_real_output, disc_fake_output):
    """Standard GAN discriminator loss on logits.

    Returns BCE-with-logits of the real logits against 1 plus that of the fake
    logits against 0 (each term mean-reduced).
    """
    bce = nn.functional.binary_cross_entropy_with_logits
    real_term = bce(disc_real_output, torch.ones_like(disc_real_output))
    fake_term = bce(disc_fake_output, torch.zeros_like(disc_fake_output))
    return real_term + fake_term

# Generator loss
def generator_loss(disc_fake_output, generated_image, target_image, mask=None, lambda_l1=100):
    """Pix2Pix generator loss: adversarial BCE term plus weighted, masked L1 term.

    Args:
        disc_fake_output: discriminator logits for (input, generated) pairs.
        generated_image: generator output.
        target_image: ground truth, same shape as ``generated_image``.
        mask: multiplier applied to the per-pixel error before reduction; must
            broadcast against the images (the training loops pass [B, 1, H, W]).
            ``None`` applies no mask.
        lambda_l1: weight of the L1 term.

    Returns:
        ``(total_loss, adversarial_term, weighted_l1_term)``.
    """
    criterion_GAN = nn.BCEWithLogitsLoss()

    # Adversarial term: the generator wants its output classified as real (1).
    loss_GAN = criterion_GAN(disc_fake_output, torch.ones_like(disc_fake_output))

    error = generated_image - target_image
    if mask is not None:
        error = error * mask
    # reshape, not view: the product may be non-contiguous (e.g. channels_last inputs).
    # NOTE(review): the reduction averages over ALL elements, masked-out zeros included,
    # so the L1 magnitude scales with the fraction of unmasked pixels.
    loss_L1 = james_stein_reduce(torch.abs(error.reshape(-1))) * lambda_l1

    loss_gen = loss_GAN + loss_L1
    return loss_gen, loss_GAN, loss_L1


def train_step(input_image, target_image, generator, discriminator, generator_optimizer, discriminator_optimizer, device):
    """Run one Pix2Pix training step: update the discriminator, then the generator.

    :param input_image: input image tensor, shape [B, C_in, H, W]
    :param target_image: target image tensor, shape [B, C_out, H, W]
    :param generator: generator model
    :param discriminator: discriminator model (takes the channel-concatenated pair)
    :param generator_optimizer: optimizer for the generator
    :param discriminator_optimizer: optimizer for the discriminator
    :param device: device the models live on; ``None`` means the input's device
    :return: (generator loss, discriminator loss, adversarial term, L1 term) as floats
    """
    LAMBDA = 100
    generator.train()
    discriminator.train()

    if device is None:
        device = input_image.device
    input_image = input_image.to(device)
    target_image = target_image.to(device)

    # Autocast only on CUDA; hard-coding device_type='cuda' just warns on CPU-only runs.
    # NOTE(review): fp16 autocast without a GradScaler can underflow small gradients.
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'
    # generator_loss takes a mask; all ones gives a plain (unmasked) L1 term.
    full_mask = torch.ones_like(target_image)

    generator_optimizer.zero_grad()
    discriminator_optimizer.zero_grad()

    # ---- Discriminator update (generator output detached) ----
    with torch.amp.autocast(device_type=device_type, enabled=use_amp):
        generated_image = generator(input_image)
        real_combined = torch.cat([input_image, target_image], dim=1)  # [B, C_in + C_out, H, W]
        fake_combined = torch.cat([input_image, generated_image.detach()], dim=1)
        loss_D = discriminator_loss(discriminator(real_combined), discriminator(fake_combined))
    loss_D.backward()
    discriminator_optimizer.step()

    # ---- Generator update ----
    # The discriminator is run again AFTER its step. Building this graph before the step
    # would fail in backward, because the step modifies the saved parameters in place.
    with torch.amp.autocast(device_type=device_type, enabled=use_amp):
        fake_combined_for_gen = torch.cat([input_image, generated_image], dim=1)
        disc_fake_output_for_gen = discriminator(fake_combined_for_gen)
        loss_G, loss_GAN, loss_L1 = generator_loss(
            disc_fake_output_for_gen, generated_image, target_image, full_mask, lambda_l1=LAMBDA)
    loss_G.backward()
    generator_optimizer.step()

    return loss_G.item(), loss_D.item(), loss_GAN.item(), loss_L1.item()


def validate_step(input_image, target_image, generator, discriminator, device):
    """Run one Pix2Pix validation step (no parameter updates).

    :param input_image: input image tensor, shape [B, C_in, H, W]
    :param target_image: target image tensor, shape [B, C_out, H, W]
    :param generator: generator model
    :param discriminator: discriminator model (takes the channel-concatenated pair)
    :param device: device the models live on; ``None`` means the input's device
    :return: (generator loss, discriminator loss, adversarial term, L1 term) as floats
    """
    LAMBDA = 100
    generator.eval()
    discriminator.eval()

    if device is None:
        device = input_image.device
    input_image = input_image.to(device)
    target_image = target_image.to(device)
    device_type = torch.device(device).type

    # Mixed precision only where it applies (CUDA).
    with torch.no_grad(), torch.amp.autocast(device_type=device_type, enabled=device_type == 'cuda'):
        generated_image = generator(input_image)

        real_combined = torch.cat([input_image, target_image], dim=1)  # [B, C_in + C_out, H, W]
        fake_combined = torch.cat([input_image, generated_image], dim=1)  # [B, C_in + C_out, H, W]

        disc_real_output = discriminator(real_combined)
        disc_fake_output = discriminator(fake_combined)
        loss_D = discriminator_loss(disc_real_output, disc_fake_output)

        disc_fake_output_for_gen = discriminator(fake_combined)
        # generator_loss takes a mask; all ones gives a plain (unmasked) L1 term.
        loss_G, loss_GAN, loss_L1 = generator_loss(
            disc_fake_output_for_gen, generated_image, target_image,
            torch.ones_like(target_image), lambda_l1=LAMBDA)

    return loss_G.item(), loss_D.item(), loss_GAN.item(), loss_L1.item()

def weights_init(m):
    """DCGAN-style init, meant for ``model.apply(weights_init)``.

    Conv2d/ConvTranspose2d weights ~ N(0, 0.02); BatchNorm2d/InstanceNorm2d
    weights ~ N(1, 0.02); biases set to 0. Missing (None) weights or biases and
    all other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        weight_mean = 0.0
    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
        weight_mean = 1.0
    else:
        return

    if m.weight is not None:
        nn.init.normal_(m.weight.data, weight_mean, 0.02)
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0)



def add_masked_gaussian_noise(x: torch.Tensor, mask: torch.Tensor, sigma: float) -> torch.Tensor:
    """Return ``x`` plus N(0, sigma^2) noise scaled by ``1 - mask``.

    With a binary mask, noise is added exactly where ``mask == 0``.

    Args:
        x: tensor of shape [B, C, H, W].
        mask: tensor of shape [B, 1, H, W]; its single channel is expanded over C.
        sigma: noise standard deviation.
    """
    channels = x.size(1)
    noise_gate = (1.0 - mask).expand(-1, channels, -1, -1)
    return x + noise_gate * (torch.randn_like(x) * sigma)



def james_stein_reduce(errors: torch.Tensor) -> torch.Tensor:
    """Reduce a tensor of per-element errors to a scalar: their mean.

    NOTE(review): this redefines ``james_stein_reduce`` from earlier in the file
    (the earlier definition computes the same value); one of the two can be dropped.

    The earlier form ``mean + shrinkage * (errors - mean).mean()`` multiplied
    the shrinkage factor by ``(errors - mean).mean()``, which is identically
    zero, so it always equalled ``errors.mean()`` in value and gradient (up to
    float rounding). James-Stein shrinkage toward the grand mean followed by
    averaging never changes the mean. A real shrinkage effect would require
    using the shrunk per-element errors before reduction, or a different target.

    Returns:
        0-dim tensor; NaN for an empty input (as ``torch.mean``).
    """
    return errors.mean()


def train_and_validate(epochs, train_loader, val_loader, in_channels, taxa, fold, patience=100):
    """Train Pix2Pix (multi-GPU capable), validate every 10 epochs, and save the generator.

    Builds a fresh ``UNetGenerator``/``PatchDiscriminator`` pair, wrapped in
    ``nn.DataParallel`` when more than one GPU is visible, and trains it with
    Adam. The generator weights (unwrapped, without the ``module.`` prefix) are
    written to ``./models_saved/pix2pix/{in_channels}/{taxa}/fold{fold}/generator.pth``.

    Loader batches are ``(input, target, _, _, mask)``. The mask is reshaped to
    ``[B, 1, H, W]`` (H, W taken from the images), and Gaussian noise is added to
    both images where the mask is 0. The noise std decays linearly from 0.1 to 0
    over the first 100 epochs.

    Validation runs when ``epoch % 10 == 0``; losses are printed when
    ``epoch % 50 == 0``. Early stopping counts validation *checks* without an
    improvement of the generator loss, so ``patience`` checks span
    ``10 * patience`` epochs.

    Raises:
        ValueError: if ``val_loader`` yields no batches.

    Returns:
        The trained generator (``nn.DataParallel``-wrapped on multi-GPU).
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    save_dir = f"./models_saved/pix2pix/{in_channels}/{taxa}/fold{fold}"

    generator = UNetGenerator(input_channels=in_channels, output_channels=in_channels).to(DEVICE)
    discriminator = PatchDiscriminator(input_channels=in_channels).to(DEVICE)

    # Multi-GPU (if available)
    if torch.cuda.device_count() > 1:
        print(f"Usando {torch.cuda.device_count()} GPUs")
        generator = torch.nn.DataParallel(generator)
        discriminator = torch.nn.DataParallel(discriminator)

    generator.apply(weights_init)
    discriminator.apply(weights_init)

    generator_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.55, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=1e-5, betas=(0.55, 0.999))

    best_val_loss = float('inf')
    epochs_no_improve = 0  # previously uninitialised: a NaN first val loss raised UnboundLocalError
    initial_sigma = 0.1
    warmup_epochs = 100

    for epoch in range(epochs):
        generator.train()
        discriminator.train()
        sigma = max(0.0, initial_sigma * (1 - epoch / warmup_epochs))

        for input_image, target_image, _, _, masks in train_loader:
            input_image = input_image.to(DEVICE).to(torch.float32)
            target_image = target_image.to(DEVICE).to(torch.float32)
            masks = masks.to(DEVICE).reshape(-1, 1, *input_image.shape[-2:]).float()  # [B, 1, H, W]
            input_image = add_masked_gaussian_noise(input_image, masks, sigma=sigma)
            target_image = add_masked_gaussian_noise(target_image, masks, sigma=sigma)

            generator_optimizer.zero_grad()
            discriminator_optimizer.zero_grad()

            generated_image = generator(input_image)

            # Discriminator step. It is kept separate from the generator step:
            # back-propagating loss_D + loss_G together (as before) also pushed the
            # generator's adversarial gradient into D, training D to call fakes "real".
            disc_real_output = discriminator(torch.cat([input_image, target_image], dim=1))
            disc_fake_output = discriminator(torch.cat([input_image, generated_image.detach()], dim=1))
            loss_D = discriminator_loss(disc_real_output, disc_fake_output)
            loss_D.backward()
            discriminator_optimizer.step()

            # Generator step, against the freshly updated discriminator.
            disc_fake_output_for_gen = discriminator(torch.cat([input_image, generated_image], dim=1))
            loss_G, _, _ = generator_loss(disc_fake_output_for_gen, generated_image, target_image, masks, lambda_l1=100)
            loss_G.backward()
            generator_optimizer.step()

        if epoch % 10 != 0:
            continue

        # ---- Validation ----
        generator.eval()
        discriminator.eval()
        gen_sum = disc_sum = gan_sum = l1_sum = 0.0
        num_batches = 0

        with torch.no_grad():
            for input_image, target_image, _, _, masksv in val_loader:
                input_image = input_image.to(DEVICE).to(torch.float32)
                target_image = target_image.to(DEVICE).to(torch.float32)
                masksv = masksv.to(DEVICE).reshape(-1, 1, *input_image.shape[-2:]).float()
                input_image = add_masked_gaussian_noise(input_image, masksv, sigma=sigma)
                # Fixed: noise was applied to input_image here, silently replacing the target.
                target_image = add_masked_gaussian_noise(target_image, masksv, sigma=sigma)

                generated_image = generator(input_image)
                disc_real_output = discriminator(torch.cat([input_image, target_image], dim=1))
                disc_fake_output = discriminator(torch.cat([input_image, generated_image], dim=1))
                loss_D = discriminator_loss(disc_real_output, disc_fake_output)
                # PatchDiscriminator in eval mode has no dropout and BatchNorm uses running
                # stats, so the fake logits can be reused for the generator loss.
                loss_G, loss_GAN, loss_L1 = generator_loss(disc_fake_output, generated_image, target_image, masksv, lambda_l1=100)

                gen_sum += loss_G.item()
                disc_sum += loss_D.item()
                gan_sum += loss_GAN.item()
                l1_sum += loss_L1.item()
                num_batches += 1

        if num_batches == 0:
            raise ValueError("val_loader yielded no batches")
        val_gen_loss = gen_sum / num_batches
        val_disc_loss = disc_sum / num_batches
        val_gan_loss = gan_sum / num_batches
        val_l1_loss = l1_sum / num_batches

        if epoch % 50 == 0:
            sys.stdout.write(f'Epoch: {epoch}, Gen Loss: {val_gen_loss:.4f}, Disc Loss: {val_disc_loss:.4f}, GAN Loss: {val_gan_loss:.4f}, L1 Loss: {val_l1_loss:.4f}\n')

        # Early stopping
        if val_gen_loss < best_val_loss:
            best_val_loss = val_gen_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print("Early stopping ativado!")
            break  # leaves the epoch loop only; the caller moves on to the next fold/rate

    # Previously the model was saved only when save_dir did not exist yet.
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, "generator.pth")
    # Save the unwrapped module so the keys load into a plain UNetGenerator.
    to_save = generator.module if isinstance(generator, torch.nn.DataParallel) else generator
    torch.save(to_save.state_dict(), model_path)
    print(f"Modelo salvo em: {model_path}")

    return generator


def train_and_validate1(epochs, train_loader, val_loader, in_channels, taxa, fold, patience=50):
    """Train Pix2Pix on a single device, validate every 10 epochs, and save the generator.

    Builds a fresh ``UNetGenerator``/``PatchDiscriminator`` pair, trains with
    Adam, and writes the generator weights to
    ``./models_saved/pix2pix/{in_channels}/{taxa}/fold{fold}/generator.pth``.
    The final weights are saved, not the best-validation ones.

    Loader batches are ``(input, target, _, _, mask)``. The mask is reshaped to
    ``[B, 1, H, W]`` (H, W taken from the images), and Gaussian noise is added to
    both images where the mask is 0. The noise std decays linearly from 0.1 to 0
    over the first 50 epochs.

    Validation runs when ``epoch % 10 == 0``; losses are printed when
    ``epoch % 50 == 0``. Early stopping counts validation *checks* without an
    improvement of the generator loss, so ``patience`` checks span
    ``10 * patience`` epochs.

    Raises:
        ValueError: if ``val_loader`` yields no batches.

    Returns:
        The trained generator.
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    save_dir = f"./models_saved/pix2pix/{in_channels}/{taxa}/fold{fold}"

    generator = UNetGenerator(input_channels=in_channels, output_channels=in_channels)
    discriminator = PatchDiscriminator(input_channels=in_channels)
    generator = generator.to(DEVICE)
    discriminator = discriminator.to(DEVICE)

    generator.apply(weights_init)
    discriminator.apply(weights_init)

    generator_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.55, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=1e-5, betas=(0.55, 0.999))

    best_val_loss = float('inf')
    epochs_no_improve = 0
    initial_sigma = 0.1
    warmup_epochs = 50

    for epoch in range(epochs):
        generator.train()
        discriminator.train()
        sigma = max(0.0, initial_sigma * (1 - epoch / warmup_epochs))

        for input_image, target_image, _, _, masks in train_loader:
            input_image = input_image.to(DEVICE).float()
            target_image = target_image.to(DEVICE).float()
            masks = masks.to(DEVICE).reshape(-1, 1, *input_image.shape[-2:]).float()  # [B, 1, H, W]

            input_image = add_masked_gaussian_noise(input_image, masks, sigma=sigma)
            target_image = add_masked_gaussian_noise(target_image, masks, sigma=sigma)

            generator_optimizer.zero_grad()
            discriminator_optimizer.zero_grad()

            generated_image = generator(input_image)

            # Discriminator step. It is kept separate from the generator step:
            # back-propagating loss_D + loss_G together (as before) also pushed the
            # generator's adversarial gradient into D, training D to call fakes "real".
            disc_real_output = discriminator(torch.cat([input_image, target_image], dim=1))
            disc_fake_output = discriminator(torch.cat([input_image, generated_image.detach()], dim=1))
            loss_D = discriminator_loss(disc_real_output, disc_fake_output)
            loss_D.backward()
            discriminator_optimizer.step()

            # Generator step, against the freshly updated discriminator.
            disc_fake_output_for_gen = discriminator(torch.cat([input_image, generated_image], dim=1))
            loss_G, _, _ = generator_loss(disc_fake_output_for_gen, generated_image, target_image, masks, lambda_l1=100)
            loss_G.backward()
            generator_optimizer.step()

        if epoch % 10 != 0:
            continue

        # ---- Validation ----
        generator.eval()
        discriminator.eval()
        gen_sum = disc_sum = gan_sum = l1_sum = 0.0
        num_batches = 0

        with torch.no_grad():
            for input_image, target_image, _, _, masksv in val_loader:
                input_image = input_image.to(DEVICE).float()
                target_image = target_image.to(DEVICE).float()
                masksv = masksv.to(DEVICE).reshape(-1, 1, *input_image.shape[-2:]).float()

                input_image = add_masked_gaussian_noise(input_image, masksv, sigma=sigma)
                target_image = add_masked_gaussian_noise(target_image, masksv, sigma=sigma)

                generated_image = generator(input_image)
                disc_real_output = discriminator(torch.cat([input_image, target_image], dim=1))
                disc_fake_output = discriminator(torch.cat([input_image, generated_image], dim=1))
                loss_D = discriminator_loss(disc_real_output, disc_fake_output)
                # PatchDiscriminator in eval mode has no dropout and BatchNorm uses running
                # stats, so the fake logits can be reused for the generator loss.
                loss_G, loss_GAN, loss_L1 = generator_loss(disc_fake_output, generated_image, target_image, masksv, lambda_l1=100)

                gen_sum += loss_G.item()
                disc_sum += loss_D.item()
                gan_sum += loss_GAN.item()
                l1_sum += loss_L1.item()
                num_batches += 1

        if num_batches == 0:
            raise ValueError("val_loader yielded no batches")
        val_gen_loss = gen_sum / num_batches
        val_disc_loss = disc_sum / num_batches
        val_gan_loss = gan_sum / num_batches
        val_l1_loss = l1_sum / num_batches

        if epoch % 50 == 0:
            sys.stdout.write(f'Epoch: {epoch}, Gen Loss: {val_gen_loss:.4f}, Disc Loss: {val_disc_loss:.4f}, GAN Loss: {val_gan_loss:.4f}, L1 Loss: {val_l1_loss:.4f}\n')

        # Early stopping
        if val_gen_loss < best_val_loss:
            best_val_loss = val_gen_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if val_gen_loss > best_val_loss and epochs_no_improve >= patience:
            print("Early stopping ativado!")
            break

    os.makedirs(save_dir, exist_ok=True)
    # Unwrap DataParallel (if a caller ever wraps it) so the keys load into a plain UNetGenerator.
    to_save = generator.module if isinstance(generator, torch.nn.DataParallel) else generator
    torch.save(to_save.state_dict(), os.path.join(save_dir, "generator.pth"))

    print(f"Modelo salvo em: {save_dir}/generator.pth")
    return generator
              
              

# Test / evaluation function
def test(gen_Z, test_loader, taxa, fold, chanells):
    """Evaluate ``gen_Z`` on ``test_loader`` and write per-sample metrics to CSV.

    Batches are ``(zebra, horse, min_val, max_val, mask)`` as produced by
    ``LoaderDataset``. They are zipped one-to-one with
    ``test_loader.dataset.horse_images``, which assumes ``batch_size=1`` and no
    shuffling.

    For each sample:
    - the generator maps ``horse`` to a fake zebra;
    - both images are mapped back to the original value range by inverting
      ``LoaderDataset.custom_normalize`` with the zebra's min/max;
    - both images are summed over channels;
    - metrics are computed on the pixels where ``mask == 1``.

    The CSV is written to
    ``./resultados/resultados_pix/result_{chanells}c_{taxa}_{fold}.csv``. Its
    columns are mae, asmape, mape (percent), rmse, and scale (max - min of the
    denormalized, channel-summed ground truth). Metrics are NaN for samples with
    no unmasked pixels.
    """
    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
    gen_Z.eval()

    names = test_loader.dataset.horse_images
    df = pd.DataFrame([], columns=['mae', 'asmape', 'mape', 'rmse', 'scale'], index=names)

    with torch.no_grad():
        for (zebra, horse, min_val, max_val, mask), name in zip(test_loader, names):
            fake_zebra = gen_Z(horse.to(DEVICE)).cpu()
            zebra = zebra.cpu()

            # Invert custom_normalize: x = (x_n + 1) / 2 * (max - min) + min.
            # (Previously min/max were applied as if they were std/mean: x_n * min + max.)
            lo = torch.as_tensor(min_val, dtype=torch.float32).cpu().reshape(-1, 1, 1, 1)
            hi = torch.as_tensor(max_val, dtype=torch.float32).cpu().reshape(-1, 1, 1, 1)
            span = (hi - lo).clamp(min=1e-5)  # same clamp as custom_normalize
            zebra = (zebra + 1) / 2 * span + lo
            fake_zebra = (fake_zebra + 1) / 2 * span + lo

            # Sum over channels and flatten to one value per pixel.
            zebra = torch.sum(zebra, dim=1).flatten()
            fake_zebra = torch.sum(fake_zebra, dim=1).flatten()

            # Score only mask == 1 pixels, for every metric (asmape already did this).
            # Multiplying by the mask instead kept masked-out zeros in the averages.
            # Assumes the mask has one value per pixel, as the training loops' reshape implies.
            keep = torch.as_tensor(mask).flatten() == 1
            y_true = zebra[keep].numpy()
            y_pred = fake_zebra[keep].numpy()
            scale = float(zebra.max() - zebra.min())

            if y_true.size == 0:
                df.loc[name] = [np.nan, np.nan, np.nan, np.nan, scale]
                continue

            mae_value = round(mae(y_true, y_pred), 3)
            mape_value = round(mape(y_true, y_pred) * 100, 3)
            rmse_value = round(np.sqrt(mse(y_true, y_pred)), 3)
            smape_value = round(asmape(y_true, y_pred), 3)

            df.loc[name] = [mae_value, smape_value, mape_value, rmse_value, scale]

    directory = "./resultados/resultados_pix"
    os.makedirs(directory, exist_ok=True)
    df.to_csv(os.path.join(directory, f'result_{str(chanells)}c_{taxa}_{fold}.csv'))


# Training parameters
# NOTE(review): the notebook output captured in this file is a CUDA out-of-memory error
# raised during training; this batch size, combined with UNetGenerator upsampling inputs
# to 256x256, is a likely contributor — confirm before reusing.
BATCH_SIZE = 400
NUM_EPOCHS = 50_000  # upper bound; training loops may stop earlier via early stopping


In [3]:
import os

# Every split and every mask index lives under the same dataset root.
TRAIN_DIR = VAL_DIR = os.path.abspath("../dataset_final")
INDEX_TRAIN = INDEX_VAL = INDEX_TEST = TRAIN_DIR


In [4]:

def main(in_channels):
    """Train and evaluate one Pix2Pix model per (taxa, fold) for ``in_channels`` channels.

    Paths follow ``<root>/{label|input}/<taxa>/folds/fold<k>/<split>`` for images
    and ``<index_root>/input/<taxa>/folds/fold<k>/<index_dir>`` for masks.
    """

    def _fold_dataset(image_root, index_root, taxa, fold, split, index_dir):
        """Build the LoaderDataset for one (taxa, fold, split)."""
        fold_parts = (taxa, "folds", f"fold{fold}")
        return LoaderDataset(
            root_zebra=os.path.join(image_root, "label", *fold_parts, split),
            root_horse=os.path.join(image_root, "input", *fold_parts, split),
            root_masks=os.path.join(index_root, "input", *fold_parts, index_dir),
            chanels=in_channels,
        )

    for taxa in ['10', '20', '30', '40']:
        for fold in ['1', '2', '3', '4', '5']:
            dataset = _fold_dataset(TRAIN_DIR, INDEX_TRAIN, taxa, fold, "train", "index_train")
            val_dataset = _fold_dataset(VAL_DIR, INDEX_VAL, taxa, fold, "val", "index_val")
            test_dataset = _fold_dataset(VAL_DIR, INDEX_TEST, taxa, fold, "test", "index")

            val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=False)
            loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=False)
            # batch_size=1 / no shuffle: test() zips batches with the dataset's file names.
            # (A duplicate test_loader that was immediately overwritten has been removed.)
            test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True)

            # Training
            generator = train_and_validate1(NUM_EPOCHS, loader, val_loader, in_channels, taxa, fold)

            # Test
            test(generator, test_loader=test_loader, taxa=taxa, fold=fold, chanells=in_channels)
            
import multiprocessing as mp
if __name__ == '__main__':
    # No effect unless running as a frozen Windows executable; harmless otherwise.
    mp.freeze_support()
    for n_channels in (1, 2, 3):
        print(f'channel {n_channels}')
        main(n_channels)

channel 1
Usando 2 GPUs


OutOfMemoryError: Caught OutOfMemoryError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 84, in _worker
    output = module(*input, **kwargs)
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_220986/2833226606.py", line 204, in forward
    u5 = self.up5(u4)
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 176, in forward
    return F.batch_norm(
  File "/home/mauricio/.conda/envs/pyt_envmau/lib/python3.10/site-packages/torch/nn/functional.py", line 2512, in batch_norm
    return torch.batch_norm(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 200.00 MiB. GPU 1 has a total capacity of 23.64 GiB of which 127.06 MiB is free. Process 3538589 has 19.64 GiB memory in use. Including non-PyTorch memory, this process has 3.86 GiB memory in use. Of the allocated memory 3.19 GiB is allocated by PyTorch, and 126.68 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
