In [6]:
import os
import numpy as np
import itertools
import datetime
import time
# import torchvision.transforms as transforms
import torch.nn as nn
import torch
# from torchvision.utils import save_image
import sys 
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


def asmape(y_true, y_pred, mask=None):
    """Return a symmetric mean absolute percentage error, in percent.

    The value is ``100 * sum(|y_pred - y_true| / (|y_true| + |y_pred|)) / n``
    where ``n`` is ``len(y_true)`` after the optional mask has been applied.
    Terms that evaluate to NaN (both values are zero) contribute nothing to
    the sum but are still counted in ``n``.

    Args:
        y_true: Reference values (list, NumPy array, or anything
            ``np.asarray`` can convert).
        y_pred: Predicted values, same shape as ``y_true``.
        mask: Optional array-like of the same shape; only the positions where
            ``mask == 1`` are kept.

    Returns:
        The error as a float, or NaN when no element is selected.
    """
    # Convert *before* masking so that plain Python lists can be
    # boolean-indexed as well.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if mask is not None:
        keep = np.asarray(mask) == 1
        y_true, y_pred = y_true[keep], y_pred[keep]

    len_ = len(y_true)
    if len_ == 0:
        return float("nan")

    # 0/0 terms become NaN and are dropped by nansum; silence the warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.abs(y_pred - y_true) / (np.abs(y_true) + np.abs(y_pred))

    return 100 * (np.nansum(ratio) / len_)



import pandas as pd 
from sklearn.metrics import mean_absolute_percentage_error as mape, mean_absolute_error as mae, mean_squared_error as mse




import torch.nn as nn
import torch.nn.functional as F
import torch


def weights_init_normal(m):
    """Re-initialise a layer in place based on its class name.

    Layers whose class name contains ``"Conv"`` get weights drawn from
    N(0, 0.02). Layers whose class name contains ``"BatchNorm2d"`` get
    weights drawn from N(1, 0.02) and a zero bias. Any other module is left
    untouched. Intended to be passed to ``nn.Module.apply``.
    """
    layer_kind = type(m).__name__

    if "Conv" in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm2d" in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


##############################
#           DATASET
##############################

class LoaderDataset(Dataset):
    """Dataset that pairs ``.npy`` arrays from three directories.

    Files in each directory are listed and sorted by name; item ``i`` takes
    the ``i``-th file of each list, wrapping around with a modulo when a list
    is shorter. The "zebra" array is the one whose normalisation statistics
    are returned alongside the sample.
    """

    def __init__(self, root_zebra, root_horse, root_masks, chanels=3):
        """Index the three directories.

        Args:
            root_zebra: Directory of the "zebra" ``.npy`` files.
            root_horse: Directory of the "horse" ``.npy`` files.
            root_masks: Directory of the mask/index ``.npy`` files.
            chanels: Number of channels to keep (3 keeps all, 2 keeps the
                first two, 1 sums all channels into a single one).
        """
        self.root_zebra, self.root_horse, self.root_index = (
            root_zebra,
            root_horse,
            root_masks,
        )

        self.zebra_images = sorted(os.listdir(root_zebra))
        self.horse_images = sorted(os.listdir(root_horse))
        self.index = sorted(os.listdir(root_masks))

        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        self.index_len = len(self.index)
        # The mask list does not take part in the dataset length.
        self.length_dataset = max(self.zebra_len, self.horse_len)
        self.chanels = chanels

    def __len__(self):
        """Return the larger of the zebra and horse file counts."""
        return self.length_dataset

    @staticmethod
    def custom_normalize(image):
        """Min-max scale ``image`` to [-1, 1].

        Returns:
            A tuple ``(normalized, min_val, max_val)`` where ``min_val`` and
            ``max_val`` are the extrema of the input as 0-d tensors.
        """
        tensor = torch.tensor(image, dtype=torch.float32)
        lowest = tensor.min()
        highest = tensor.max()
        # Clamp the range so a constant image does not divide by zero.
        span = (highest - lowest).clamp(min=1e-5)
        return 2 * (tensor - lowest) / span - 1, lowest, highest

    def __getitem__(self, index):
        """Load, reshape, channel-select and normalise the sample at ``index``.

        Returns:
            ``(zebra, horse, zebra_min, zebra_max, mask)``. Both images are
            channel-first float tensors scaled to [-1, 1]; the mask is
            returned as a float32 tensor with the shape stored on disk.
        """
        zebra_path = os.path.join(
            self.root_zebra, self.zebra_images[index % self.zebra_len]
        )
        horse_path = os.path.join(
            self.root_horse, self.horse_images[index % self.horse_len]
        )
        index_path = os.path.join(
            self.root_index, self.index[index % self.index_len]
        )

        zebra_arr = np.load(zebra_path)
        horse_arr = np.load(horse_path)
        mask_arr = np.load(index_path)

        # Arrays stored with extra leading dimensions are collapsed to HWC.
        if zebra_arr.ndim > 3:
            zebra_arr = zebra_arr.reshape(32, 32, 3)
            horse_arr = horse_arr.reshape(32, 32, 3)

        # HWC -> CHW
        zebra_arr = zebra_arr.transpose(2, 0, 1)
        horse_arr = horse_arr.transpose(2, 0, 1)

        if self.chanels == 2:
            zebra_arr, horse_arr = zebra_arr[:2], horse_arr[:2]
        elif self.chanels == 1:
            zebra_arr = zebra_arr.sum(axis=0, keepdims=True)
            horse_arr = horse_arr.sum(axis=0, keepdims=True)

        zebra_t, zebra_min, zebra_max = LoaderDataset.custom_normalize(zebra_arr)
        horse_t = LoaderDataset.custom_normalize(horse_arr)[0]

        mask_t = torch.tensor(mask_arr, dtype=torch.float32)

        return zebra_t, horse_t, zebra_min, zebra_max, mask_t


class ToFloat32:
    """Callable transform that casts its input to 32-bit floating point."""

    def __call__(self, image, **kwargs):
        """Return ``image.float()``; extra keyword arguments are ignored."""
        as_float = image.float()
        return as_float






import torch
import torch.nn as nn
import torch.nn.functional as F

# === UNet Generator ===

# === Custom instance normalization (as in the TF implementation) ===
class InstanceNormalization(nn.Module):
    """Instance normalisation over (H, W) with a learnable per-channel affine.

    Each (sample, channel) plane of an ``(N, C, H, W)`` input is normalised
    to zero mean and unit (biased) variance, then multiplied by ``scale`` and
    shifted by ``offset``, both of shape ``(C,)``.

    Args:
        epsilon: Value added to the variance for numerical stability.
        num_features: Optional channel count. When given, ``scale`` and
            ``offset`` are created immediately, so they appear in
            ``state_dict()`` / ``parameters()`` right after construction.
            When omitted, they are created on the first forward pass from
            the input's channel dimension; in that case run one forward pass
            before building an optimizer or loading a checkpoint.
    """

    def __init__(self, epsilon=1e-5, num_features=None):
        super().__init__()
        self.epsilon = epsilon

        if num_features is None:
            # Declare the slots so the attributes always exist; they are
            # filled in lazily by forward().
            self.register_parameter('scale', None)
            self.register_parameter('offset', None)
        else:
            self.scale = nn.Parameter(torch.ones(num_features))
            self.offset = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Normalise ``x`` (N, C, H, W) and apply the per-channel affine."""
        mean = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True, unbiased=False)
        inv = 1.0 / torch.sqrt(var + self.epsilon)
        normalized = (x - mean) * inv

        # Lazily create the affine parameters on the input's device.
        if self.scale is None:
            self.scale = nn.Parameter(torch.ones(x.size(1), device=x.device))
            self.offset = nn.Parameter(torch.zeros(x.size(1), device=x.device))

        # Reshape to (1, C, 1, 1) so the affine broadcasts over N, H and W.
        scale = self.scale.view(1, -1, 1, 1)
        offset = self.offset.view(1, -1, 1, 1)
        return scale * normalized + offset



# === Downsample and Upsample blocks ===
def downsample(in_ch, out_ch, norm_type='instancenorm', apply_norm=True):
    """Build an encoder block that halves the spatial resolution.

    The block is ``Conv2d(k=4, s=2, p=1, no bias)`` followed by an optional
    normalisation layer and ``LeakyReLU(0.2)``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        norm_type: ``'batchnorm'`` or ``'instancenorm'``; any other value
            adds no normalisation layer.
        apply_norm: When false, the normalisation layer is skipped.

    Returns:
        An ``nn.Sequential`` containing the layers in order.
    """
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)
    stages = [conv]

    if apply_norm and norm_type == 'batchnorm':
        stages.append(nn.BatchNorm2d(out_ch))
    elif apply_norm and norm_type == 'instancenorm':
        stages.append(InstanceNormalization())

    stages.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*stages)

def upsample(in_ch, out_ch, norm_type='instancenorm', apply_dropout=False):
    """Build a decoder block that doubles the spatial resolution.

    The block is ``ConvTranspose2d(k=4, s=2, p=1, no bias)`` followed by an
    optional normalisation layer, ``ReLU`` and, if requested,
    ``Dropout(0.5)``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        norm_type: ``'batchnorm'`` or ``'instancenorm'``; any other value
            adds no normalisation layer.
        apply_dropout: Append a 50% dropout layer after the activation.

    Returns:
        An ``nn.Sequential`` containing the layers in order.
    """
    deconv = nn.ConvTranspose2d(
        in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False
    )
    stages = [deconv]

    if norm_type == 'batchnorm':
        stages.append(nn.BatchNorm2d(out_ch))
    elif norm_type == 'instancenorm':
        stages.append(InstanceNormalization())

    stages.append(nn.ReLU(inplace=True))
    if apply_dropout:
        stages.append(nn.Dropout(0.5))

    return nn.Sequential(*stages)


class UNetGenerator(nn.Module):
    """U-Net style generator with eight encoder and seven decoder stages.

    The input is first resized to ``target_size x target_size``, passed
    through eight stride-2 ``downsample`` blocks and seven ``upsample``
    blocks with skip connections, mapped to ``output_channels`` by a final
    transposed convolution followed by ``tanh``, and finally resized back to
    the input's original spatial size.

    Args:
        input_channels: Channels of the input tensor.
        output_channels: Channels of the generated tensor.
        norm_type: Normalisation used inside the blocks (forwarded to
            ``downsample`` / ``upsample``).
        target_size: Side length the input is resized to before encoding.
    """

    def __init__(self, input_channels=3, output_channels=3, norm_type='instancenorm', target_size=256):
        super().__init__()
        self.target_size = target_size

        # (in_channels, out_channels, apply_norm) for down1 .. down8.
        encoder_specs = [
            (input_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
        ]
        for stage, (c_in, c_out, use_norm) in enumerate(encoder_specs, start=1):
            setattr(
                self,
                f"down{stage}",
                downsample(c_in, c_out, norm_type, apply_norm=use_norm),
            )

        # (in_channels, out_channels, apply_dropout) for up1 .. up7. Input
        # widths from up2 onwards are doubled by the skip concatenation.
        decoder_specs = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        for stage, (c_in, c_out, use_dropout) in enumerate(decoder_specs, start=1):
            setattr(
                self,
                f"up{stage}",
                upsample(c_in, c_out, norm_type, apply_dropout=use_dropout),
            )

        self.final = nn.ConvTranspose2d(128, output_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Generate an output with the same spatial size as ``x``."""
        orig_size = x.shape[-2:]  # (H, W)

        # Resize the input to the working resolution.
        hidden = F.interpolate(
            x,
            size=(self.target_size, self.target_size),
            mode='bilinear',
            align_corners=False,
        )

        # Encoder: keep every stage's activation for the skip connections.
        skips = []
        for stage in range(1, 9):
            hidden = getattr(self, f"down{stage}")(hidden)
            skips.append(hidden)

        # The deepest activation is the decoder input, not a skip.
        hidden = skips.pop()

        # Decoder: upsample, then concatenate with the matching encoder
        # activation (deepest remaining first).
        for stage in range(1, 8):
            hidden = getattr(self, f"up{stage}")(hidden)
            hidden = torch.cat([hidden, skips.pop()], dim=1)

        output = self.tanh(self.final(hidden))

        # Resize the result back to the caller's original resolution.
        return F.interpolate(output, size=orig_size, mode='bilinear', align_corners=False)



def test(gen_Z, test_loader, taxa, fold, chanells):
    """Evaluate a generator on a test loader and save per-sample metrics.

    For every sample the generator output and the ground truth are mapped
    back from the [-1, 1] range to the original value range, summed over the
    channel axis, flattened and multiplied by the sample's mask. MAE, MAPE,
    RMSE and ASMAPE are then written to one row of a CSV file under
    ``./resultados/resultados_discogan``.

    Rows are labelled with ``test_loader.dataset.horse_images``, which only
    lines up with the batches when the loader uses ``batch_size=1`` and
    ``shuffle=False``.

    Args:
        gen_Z: Generator module; it must already live on the device selected
            here (``cuda:0`` when available, otherwise CPU).
        test_loader: DataLoader yielding
            ``(zebra, horse, zebra_min, zebra_max, mask)`` batches.
        taxa: Identifier used in the output file name.
        fold: Fold identifier used in the output file name.
        chanells: Channel count used in the output file name.
    """
    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
    sample_names = test_loader.dataset.horse_images

    gen_Z.eval()

    with torch.no_grad():
        df = pd.DataFrame(
            [],
            columns=['mae', 'asmape', 'mape', 'rmse', 'scale'],
            index=sample_names,
        )

        for (zebra, horse, min_val, max_val, masks), name in zip(test_loader, sample_names):
            # Only the generator input needs to be on the compute device;
            # everything after inference happens on the CPU.
            fake_zebra = gen_Z(horse.to(DEVICE)).cpu()
            zebra = zebra.cpu()

            # Undo the dataset's min-max scaling to [-1, 1]:
            #   normalized = 2 * (x - min) / clamp(max - min, 1e-5) - 1
            # The statistics returned by the dataset are the min and max of
            # the ground truth, so the inverse is (n + 1) / 2 * range + min.
            min_val = torch.as_tensor(min_val, dtype=torch.float32).cpu().reshape(-1, 1, 1, 1)
            max_val = torch.as_tensor(max_val, dtype=torch.float32).cpu().reshape(-1, 1, 1, 1)
            value_range = torch.clamp(max_val - min_val, min=1e-5)
            zebra = (zebra + 1) / 2 * value_range + min_val
            fake_zebra = (fake_zebra + 1) / 2 * value_range + min_val

            # Sum over the channel axis, flatten, and zero out masked entries.
            zebra = torch.sum(zebra, dim=1).flatten() * masks
            fake_zebra = torch.sum(fake_zebra, dim=1).flatten() * masks

            y_true = zebra.numpy()
            y_pred = fake_zebra.numpy()

            mae_value = round(mae(y_true, y_pred), 3)
            mape_value = round(mape(y_true, y_pred) * 100, 3)
            rmse_value = round(np.sqrt(mse(y_true, y_pred)), 3)
            smape_value = round(asmape(y_true, y_pred, masks.numpy()), 3)

            df.loc[name] = [
                mae_value,
                smape_value,
                mape_value,
                rmse_value,
                np.max(y_true) - np.min(y_true),
            ]

        # Save the per-sample results as CSV.
        directory = "./resultados/resultados_discogan"
        os.makedirs(directory, exist_ok=True)

        df.to_csv(os.path.join(directory, f'result_{str(chanells)}c_{taxa}_{fold}.csv'))



def load_checkpoint(checkpoint_file, model, DEVICE):
    """Load a saved state dict into ``model`` and return the model.

    Args:
        checkpoint_file: Path to a file written with ``torch.save`` that
            contains a state dict.
        model: Module whose parameters are overwritten in place.
        DEVICE: Device the stored tensors are mapped to while loading.

    Returns:
        The same ``model`` object, after loading.
    """
    model.load_state_dict(
        torch.load(checkpoint_file, map_location=DEVICE, weights_only=True)
    )
    return model

# ----------
#  Training / evaluation configuration
# ----------

# Every dataset and index location currently resolves to the same root.
TRAIN_DIR = VAL_DIR = INDEX_TRAIN = INDEX_VAL = INDEX_TEST = os.path.abspath(
    "../dataset_final"
)

BATCH_SIZE = 512
NUM_EPOCHS = 1000

# NOTE(review): presumably the discriminator / generator learning rates and
# the Adam betas; none of them is referenced in the visible code.
lrd, lrg = 1e-4, 1e-3
b1, b2 = 0.5, 0.999
               
def main(in_channels):
    """Evaluate the saved generator for every ``taxa`` / fold combination.

    For each of the four ``taxa`` values and five folds, a fresh
    ``UNetGenerator`` is built, its weights are loaded from
    ``./models_saved/discogan/<in_channels>/<taxa>/fold<fold>/G_AB.pth`` and
    it is scored on that fold's test split via ``test``.

    Args:
        in_channels: Number of image channels; used for the generator's
            input/output width, the dataset channel selection and the
            checkpoint / result paths.

    Raises:
        FileNotFoundError: If a checkpoint or dataset directory is missing.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    for taxa in ['10', '20', '30', '40']:
        for fold in ['1', '2', '3', '4', '5']:
            G_BA = UNetGenerator(
                input_channels=in_channels,
                output_channels=in_channels,
                norm_type='instancenorm',
            )
            G_BA.to(device)

            # One throw-away forward pass: the instance-norm layers may only
            # create their affine parameters on first use, and those must
            # exist before the checkpoint can be loaded.
            dummy_in = torch.zeros(
                1, in_channels, G_BA.target_size, G_BA.target_size, device=device
            )
            with torch.no_grad():
                G_BA(dummy_in)

            fold_parts = (str(taxa), "folds", f"fold{fold}")
            test_dataset = LoaderDataset(
                root_zebra=os.path.join(VAL_DIR, "label", *fold_parts, "test"),
                root_horse=os.path.join(VAL_DIR, "input", *fold_parts, "test"),
                root_masks=os.path.join(INDEX_TEST, "input", *fold_parts, "index"),
                chanels=in_channels,
            )

            # batch_size=1 and shuffle=False are required by test(), which
            # pairs each batch with a file name by position.
            test_loader = DataLoader(
                test_dataset, batch_size=1, shuffle=False, pin_memory=True
            )

            # Checkpoints are only read here, so the directory is not created.
            save_dir = f"./models_saved/discogan/{in_channels}/{taxa}/fold{fold}"
            model_path = os.path.join(save_dir, "G_AB.pth")
            load_checkpoint(model_path, G_BA, device)
            print(f"\n Model {model_path}\n")

            test(G_BA, test_loader, taxa, fold, in_channels)
            # test() leaves the model in eval mode; restore training mode.
            G_BA.train()


if __name__ == '__main__':
    # Run the full evaluation once per channel configuration.
    for n_channels in (1, 2, 3):
        main(n_channels)


 Model ./models_saved/discogan/1/10/fold1/G_AB.pth


 Model ./models_saved/discogan/1/10/fold2/G_AB.pth


 Model ./models_saved/discogan/1/10/fold3/G_AB.pth


 Model ./models_saved/discogan/1/10/fold4/G_AB.pth


 Model ./models_saved/discogan/1/10/fold5/G_AB.pth


 Model ./models_saved/discogan/1/20/fold1/G_AB.pth


 Model ./models_saved/discogan/1/20/fold2/G_AB.pth


 Model ./models_saved/discogan/1/20/fold3/G_AB.pth


 Model ./models_saved/discogan/1/20/fold4/G_AB.pth


 Model ./models_saved/discogan/1/20/fold5/G_AB.pth


 Model ./models_saved/discogan/1/30/fold1/G_AB.pth


 Model ./models_saved/discogan/1/30/fold2/G_AB.pth


 Model ./models_saved/discogan/1/30/fold3/G_AB.pth


 Model ./models_saved/discogan/1/30/fold4/G_AB.pth


 Model ./models_saved/discogan/1/30/fold5/G_AB.pth


 Model ./models_saved/discogan/1/40/fold1/G_AB.pth


 Model ./models_saved/discogan/1/40/fold2/G_AB.pth


 Model ./models_saved/discogan/1/40/fold3/G_AB.pth


 Model ./models_saved/discogan/1/40/fold4/G_A