# In [None]:
# -*- coding: utf-8 -*-
"""CycleGAN.ipynb

Automatically generated by Colab.

Original file is located at
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import mean_absolute_percentage_error as mape, mean_absolute_error as mae,mean_squared_error as mse
import sys
from torch.utils.data import Dataset
import random, torch, os, numpy as np
import torch.nn as nn
import copy
import torch.nn.functional as F

import torch.optim as optim
from tqdm import tqdm
import pandas as pd
from matplotlib.path import Path


import sys


# === Evaluation metric ===


def asmape(y_true, y_pred, mask=None):
    """Return the symmetric absolute percentage error (in %) between two arrays.

    Computes ``100 * sum(|p - t| / (|t| + |p|)) / n``. Elements where both the
    true and predicted values are 0 produce 0/0; they are ignored in the sum
    (``np.nansum``) but still counted in ``n``.

    Args:
        y_true: Ground-truth values (list, ndarray, or CPU tensor; any shape).
        y_pred: Predicted values, same shape as ``y_true``.
        mask: Optional array of the same shape. Only elements where
            ``mask == 1`` are scored.

    Returns:
        The score as a float, or NaN when no elements remain to score.
    """
    # Convert first so that masking works for list inputs too.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if mask is not None:
        keep = np.asarray(mask) == 1
        y_true, y_pred = y_true[keep], y_pred[keep]

    # Use the total element count, not len(), which only counts the first axis.
    n = y_true.size
    if n == 0:
        return float("nan")

    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.abs(y_pred - y_true) / (np.abs(y_true) + np.abs(y_pred))
    return 100 * (np.nansum(ratio) / n)



class LoaderDataset(Dataset):
    """Dataset of paired ``.npy`` images (label "zebra", input "horse") plus masks.

    Each folder's file names are sorted, and each folder is indexed modulo its own
    length, so the three folders may have different sizes. The dataset length is
    the larger of the zebra and horse folder sizes.

    Each item is ``(zebra, horse, zebra_min, zebra_max, mask)``:
      * ``zebra`` and ``horse`` are float32 (C, H, W) tensors, each min-max
        scaled to [-1, 1] by ``custom_normalize`` using its own min and max;
      * ``zebra_min`` and ``zebra_max`` are the zebra image's original min and max,
        which are needed to undo the scaling;
      * ``mask`` is the loaded mask array as a float32 tensor.

    Args:
        root_zebra: Directory of label images.
        root_horse: Directory of input images.
        root_masks: Directory of mask arrays.
        chanels: 2 keeps the first two channels, 1 sums all channels into one,
            and any other value keeps every channel.
    """

    def __init__(self, root_zebra, root_horse, root_masks, chanels=3):
        self.root_zebra = root_zebra
        self.root_horse = root_horse
        self.root_index = root_masks

        self.zebra_images = sorted(os.listdir(root_zebra))
        self.horse_images = sorted(os.listdir(root_horse))
        self.index = sorted(os.listdir(root_masks))

        self.length_dataset = max(len(self.zebra_images), len(self.horse_images))
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        self.index_len = len(self.index)
        self.chanels = chanels

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def custom_normalize(image):
        """Min-max scale ``image`` to [-1, 1].

        Returns:
            ``(scaled, min, max)``. ``min`` and ``max`` are the original extremes.
            The divisor is ``max - min`` clamped to at least 1e-5.
        """
        image = torch.tensor(image, dtype=torch.float32)
        min_val = torch.min(image)
        max_val = torch.max(image)
        scale = torch.clamp(max_val - min_val, min=1e-5)  # avoid division by zero on constant images
        image_normalized = 2 * (image - min_val) / scale - 1  # map to [-1, 1]
        return image_normalized, min_val, max_val

    @staticmethod
    def _to_chw(img):
        """Convert an (H, W, C) array to (C, H, W).

        An array with more than 3 dims is first reshaped to (32, 32, 3).
        Each image is checked on its own, so the label and input files may
        differ in layout.
        """
        if img.ndim > 3:
            img = img.reshape(32, 32, 3)
        return np.transpose(img, (2, 0, 1))

    def _select_channels(self, img):
        """Apply the ``chanels`` policy to a (C, H, W) array."""
        if self.chanels == 2:
            return img[:2, :, :]
        if self.chanels == 1:
            return np.sum(img, axis=0, keepdims=True)
        return img

    def __getitem__(self, index):
        zebra_name = self.zebra_images[index % self.zebra_len]
        horse_name = self.horse_images[index % self.horse_len]
        mask_name = self.index[index % self.index_len]

        zebra_img = np.load(os.path.join(self.root_zebra, zebra_name))
        horse_img = np.load(os.path.join(self.root_horse, horse_name))
        mask = np.load(os.path.join(self.root_index, mask_name))

        zebra_img = self._select_channels(self._to_chw(zebra_img))
        horse_img = self._select_channels(self._to_chw(horse_img))

        zebra_img, min_val_z, max_val_z = LoaderDataset.custom_normalize(zebra_img)
        horse_img, _, _ = LoaderDataset.custom_normalize(horse_img)

        mask = torch.tensor(mask, dtype=torch.float32)

        return zebra_img, horse_img, min_val_z, max_val_z, mask
    



# class ToFloat32:
#     def __call__(self, image, **kwargs):
#         return image.float()




# === Custom instance normalization (as in the TF implementation) ===
class InstanceNormalization(nn.Module):
    """Per-sample, per-channel normalization with a learnable affine transform.

    Each (sample, channel) plane is standardized over its spatial dims using the
    biased variance. The result is then multiplied by ``scale`` and shifted by
    ``offset``.

    Args:
        epsilon: Added to the variance for numerical stability.
        num_features: Number of channels. When given, ``scale`` and ``offset``
            are registered here in the constructor. They are then part of
            ``parameters()``, so any optimizer built afterwards updates them,
            and part of ``state_dict()`` from the start. When omitted, they
            are created lazily on the first forward pass. In that case an
            optimizer built before the first call never sees them.
    """

    def __init__(self, epsilon=1e-5, num_features=None):
        super().__init__()
        self.epsilon = epsilon
        if num_features is not None:
            self.scale = nn.Parameter(torch.ones(num_features))
            self.offset = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # x: (N, C, H, W); statistics are taken over H and W only.
        mean = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True, unbiased=False)
        inv = 1.0 / torch.sqrt(var + self.epsilon)
        normalized = (x - mean) * inv

        # Backward-compatible lazy path for instances built without num_features.
        if not hasattr(self, 'scale'):
            self.scale = nn.Parameter(torch.ones(x.size(1), device=x.device))
            self.offset = nn.Parameter(torch.zeros(x.size(1), device=x.device))
        # Reshape to (1, C, 1, 1) so the affine terms broadcast per channel.
        scale = self.scale.view(1, -1, 1, 1)
        offset = self.offset.view(1, -1, 1, 1)
        return scale * normalized + offset

# === Downsample and upsample building blocks ===
def downsample(in_ch, out_ch, norm_type='batchnorm', apply_norm=True):
    """Conv(k=4, s=2, p=1, no bias) -> optional norm -> LeakyReLU(0.2).

    Halves the spatial size of even-sized inputs. ``norm_type`` selects
    'batchnorm' or 'instancenorm'. Any other value, or ``apply_norm=False``,
    adds no normalization layer.
    """
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)]
    if apply_norm:
        if norm_type == 'batchnorm':
            layers.append(nn.BatchNorm2d(out_ch))
        elif norm_type == 'instancenorm':
            # Size the affine params now so optimizers built later include them.
            layers.append(InstanceNormalization(num_features=out_ch))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)

def upsample(in_ch, out_ch, norm_type='batchnorm', apply_dropout=False):
    """ConvTranspose(k=4, s=2, p=1, no bias) -> optional norm -> ReLU -> optional Dropout(0.5).

    Doubles the spatial size. ``norm_type`` selects 'batchnorm' or
    'instancenorm'. Any other value adds no normalization layer.
    """
    layers = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)]
    if norm_type == 'batchnorm':
        layers.append(nn.BatchNorm2d(out_ch))
    elif norm_type == 'instancenorm':
        # Size the affine params now so optimizers built later include them.
        layers.append(InstanceNormalization(num_features=out_ch))
    layers.append(nn.ReLU(inplace=True))
    if apply_dropout:
        layers.append(nn.Dropout(0.5))
    return nn.Sequential(*layers)


class UNetGenerator(nn.Module):
    """Pix2pix-style U-Net generator that runs at a fixed internal resolution.

    The input is bilinearly resized to ``target_size x target_size``. It then
    passes through 8 stride-2 encoder stages (``down1``..``down8``) and 7
    decoder stages (``up1``..``up7``). Each decoder output is concatenated with
    the matching encoder feature map. A final transposed conv + tanh maps to
    ``output_channels``, and the result is resized back to the input's original
    (H, W).

    ``target_size`` must be divisible by 256 (2**8). Otherwise the encoder
    feature maps and the decoder outputs they are concatenated with differ in
    spatial size.
    """

    def __init__(self, input_channels=3, output_channels=3, norm_type='batchnorm', target_size=256):
        super().__init__()
        self.target_size = target_size

        # (in, out, apply_norm) per encoder level; the first level is not normalized.
        encoder_spec = [
            (input_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
        ]
        for level, (c_in, c_out, use_norm) in enumerate(encoder_spec, start=1):
            setattr(self, f"down{level}", downsample(c_in, c_out, norm_type, apply_norm=use_norm))

        # (in, out, apply_dropout) per decoder level; inputs after up1 are
        # doubled by the skip concatenation.
        decoder_spec = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        for level, (c_in, c_out, use_dropout) in enumerate(decoder_spec, start=1):
            setattr(self, f"up{level}", upsample(c_in, c_out, norm_type, apply_dropout=use_dropout))

        self.final = nn.ConvTranspose2d(128, output_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        in_hw = x.shape[-2:]
        h = F.interpolate(x, size=(self.target_size, self.target_size), mode='bilinear', align_corners=False)

        # Encoder: keep every level's output for the skip connections.
        skips = []
        for level in range(1, 9):
            h = getattr(self, f"down{level}")(h)
            skips.append(h)

        # Decoder: start from the bottleneck, then pair each up-level with the
        # encoder output of the mirrored depth (d7, d6, ..., d1).
        h = skips.pop()
        for level in range(1, 8):
            h = torch.cat([getattr(self, f"up{level}")(h), skips.pop()], dim=1)

        out = self.tanh(self.final(h))
        return F.interpolate(out, size=in_hw, mode='bilinear', align_corners=False)

# === PatchGAN Discriminator ===
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator that outputs a grid of real/fake logits.

    Args:
        input_channels: Channels of each image passed in.
        norm_type: Passed to the two inner ``downsample`` stages. For the
            512-channel stage, 'batchnorm' selects BatchNorm2d and any other
            value selects InstanceNormalization.
        target: When true, the first conv expects ``2 * input_channels``
            channels, and ``forward`` concatenates ``inp`` with ``target``
            along the channel axis. Calling such an instance without ``target``
            passes ``inp`` alone, which has the wrong channel count for that
            conv.
    """

    def __init__(self, input_channels=3, norm_type='batchnorm', target=True):
        super().__init__()
        self.target = target
        first_in = input_channels * (2 if target else 1)

        # Index order matters: it fixes the state_dict keys (model.0 ... model.9).
        layers = [
            nn.Conv2d(first_in, 64, kernel_size=4, stride=2, padding=1),  # no norm on the first layer
            nn.LeakyReLU(0.2, inplace=True),
        ]
        layers += [
            downsample(64, 128, norm_type),
            downsample(128, 256, norm_type),
        ]
        layers += [
            nn.ZeroPad2d(1),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512) if norm_type == 'batchnorm' else InstanceNormalization(),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        layers += [
            nn.ZeroPad2d(1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, inp, target=None):
        paired = self.target and target is not None
        x = torch.cat([inp, target], dim=1) if paired else inp
        return self.model(x)






def save_checkpoint(model, optimizer, filename="models/checkpoint.pth.tar"):
    """Save ``model`` and ``optimizer`` state dicts to ``filename``.

    The file holds ``{"state_dict": ..., "optimizer": ...}``, which is the
    format ``load_checkpoint`` reads. Missing parent directories are created
    at any depth. A bare filename is written to the current directory.
    """
    # dirname() handles nested, absolute and directory-less paths. The previous
    # split('/')[0] turned a bare filename into a directory of the same name.
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state saved by ``save_checkpoint``.

    Tensors are mapped onto the module-level ``DEVICE``. After loading, every
    optimizer param group's learning rate is overwritten with ``lr``, so the
    caller's value wins over the one stored in the checkpoint.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model_state, optim_state = saved["state_dict"], saved["optimizer"]

    model.load_state_dict(model_state)
    optimizer.load_state_dict(optim_state)

    for group in optimizer.param_groups:
        group["lr"] = lr


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) with ``seed``.

    Also sets PYTHONHASHSEED, forces deterministic cuDNN kernels and disables
    cuDNN autotuning.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False



def add_masked_gaussian_noise(x: torch.Tensor, mask: torch.Tensor, sigma: float) -> torch.Tensor:
    """Add N(0, sigma^2) noise to ``x`` only where ``mask`` is 0.

    Args:
        x: Batch of images, shape [B, C, H, W].
        mask: Shape [B, 1, H, W]. The single mask channel is expanded over
            all C channels. Positions with mask 1 are left untouched.
        sigma: Standard deviation of the noise.

    Returns:
        A new tensor; ``x`` is not modified.
    """
    perturbation = torch.randn_like(x) * sigma
    noise_region = (1.0 - mask).expand(-1, x.size(1), -1, -1)
    return x + noise_region * perturbation



def james_stein_reduce(errors: torch.Tensor) -> torch.Tensor:
    """Reduce a tensor of per-element losses to a scalar: its arithmetic mean.

    This was written as a James-Stein-style shrinkage toward the mean followed
    by averaging, ``mean + s * (errors - mean).mean()``. That expression always
    equals the mean:
      * ``(errors - mean).mean()`` is identically 0, so the shrinkage term
        contributes nothing to the value or the gradient;
      * with ``var = ||errors - mean||^2 / n`` the factor ``s`` was essentially
        the constant ``2 / n`` anyway.
    The mean is therefore computed directly, without the extra passes over
    the data.

    Args:
        errors: Per-element losses (any shape).

    Returns:
        A 0-dim tensor. It is NaN for an empty input, as ``torch.mean`` returns.
    """
    return errors.mean()

def _adv_bce(bce_loss_fn, logits, is_real):
    """BCE of ``logits`` against an all-ones (real) or all-zeros (fake) target, reduced by james_stein_reduce."""
    target = torch.ones_like(logits) if is_real else torch.zeros_like(logits)
    return james_stein_reduce(bce_loss_fn(logits, target).view(-1))


def _validate(disc_H, disc_Z, gen_Z, gen_H, val_loader, bce_loss_fn, sigma):
    """Evaluate all four networks on ``val_loader`` with gradients disabled.

    Returns:
        Batch-averaged ``(H_real, H_fake, loss_G_H, loss_G_Z)``:
          * ``H_real`` / ``H_fake``: mean disc_H logits on masked real horses
            and on generated horses;
          * ``loss_G_H`` / ``loss_G_Z``: the generators' adversarial BCE losses.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    for net in (disc_H, disc_Z, gen_H, gen_Z):
        net.eval()

    h_real = h_fake = g_h = g_z = 0.0
    n_batches = 0
    with torch.no_grad():
        for zebra, horse, _, _, masks in val_loader:
            zebra = zebra.to(DEVICE)
            horse = horse.to(DEVICE)
            masks = masks.to(DEVICE).view(-1, 1, *zebra.shape[-2:]).float()

            zebra_noisy = add_masked_gaussian_noise(zebra, masks, sigma=sigma)
            horse_noisy = add_masked_gaussian_noise(horse, masks, sigma=sigma)

            fake_horse = gen_H(zebra_noisy)
            fake_zebra = gen_Z(horse_noisy)

            real_h = horse * masks
            real_z = zebra * masks
            D_H_real = disc_H(real_h, real_h)
            D_H_fake = disc_H(fake_horse * masks, real_h)
            D_Z_fake = disc_Z(fake_zebra * masks, real_z)

            h_real += D_H_real.mean().item()
            h_fake += D_H_fake.mean().item()
            g_h += _adv_bce(bce_loss_fn, D_H_fake, True).item()
            g_z += _adv_bce(bce_loss_fn, D_Z_fake, True).item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("val_loader produced no batches; cannot compute validation loss")
    return h_real / n_batches, h_fake / n_batches, g_h / n_batches, g_z / n_batches


def train_fn(
    disc_H, disc_Z, gen_Z, gen_H, loader, val_loader,
    opt_disc, opt_gen, l1_loss_fn, bce_loss_fn,  # l1 and bce with reduction='none'
    d_scaler, g_scaler,
    channel, taxa, fold, pasciencia=1000
):
    """Train both CycleGAN directions on masked images, then save ``gen_Z``.

    Each epoch:
      1. Add Gaussian noise outside the mask to both inputs. The noise std
         decays linearly from 0.1 to 0 over the first 100 epochs.
      2. Discriminator step: each discriminator scores the masked real image
         and the detached generated image. In both cases it is conditioned on
         the masked real image.
      3. Generator step: adversarial losses plus ``LAMBDA_CYCLE`` times the
         masked L1 cycle losses.
    All element-wise losses are reduced with ``james_stein_reduce``.

    After every epoch the networks are validated. Training stops after
    ``pasciencia`` epochs without improvement in gen_Z's validation adversarial
    loss. The gen_Z weights from the best validation epoch are restored before
    saving to ``./models_saved/cyclegan/{channel}/{taxa}/fold{fold}/generator.pth``.

    Returns:
        ``gen_Z`` with the best-validation weights loaded.
    """
    initial_sigma = 0.1   # training noise std at epoch 0
    warmup_epochs = 100   # epochs over which the training noise decays to 0
    # NOTE(review): validation always uses sigma=0.1, not the annealed training value.
    val_sigma = 0.1
    # Mixed precision only on CUDA; on CPU autocast is simply disabled.
    use_amp = torch.device(DEVICE).type == 'cuda'

    best_val_loss = float('inf')
    best_gen_state = None
    epochs_no_improve = 0

    for epoch in range(NUM_EPOCHS):
        for net in (disc_H, disc_Z, gen_H, gen_Z):
            net.train()

        sigma = max(0.0, initial_sigma * (1 - epoch / warmup_epochs))

        for zebra, horse, _, _, masks in loader:
            zebra = zebra.to(DEVICE)
            horse = horse.to(DEVICE)
            # One mask plane per sample, sized to the images (broadcast over channels).
            masks = masks.to(DEVICE).view(-1, 1, *zebra.shape[-2:]).float()

            zebra_noisy = add_masked_gaussian_noise(zebra, masks, sigma=sigma)
            horse_noisy = add_masked_gaussian_noise(horse, masks, sigma=sigma)
            real_h = horse * masks
            real_z = zebra * masks

            # === Discriminator step ===
            with torch.amp.autocast('cuda', enabled=use_amp):
                fake_horse = gen_H(zebra_noisy)
                fake_zebra = gen_Z(horse_noisy)

                D_H_real = disc_H(real_h, real_h)
                D_H_fake = disc_H(fake_horse.detach() * masks, real_h)
                D_H_loss = (_adv_bce(bce_loss_fn, D_H_real, True)
                            + _adv_bce(bce_loss_fn, D_H_fake, False)) * 0.5

                D_Z_real = disc_Z(real_z, real_z)
                D_Z_fake = disc_Z(fake_zebra.detach() * masks, real_z)
                D_Z_loss = (_adv_bce(bce_loss_fn, D_Z_real, True)
                            + _adv_bce(bce_loss_fn, D_Z_fake, False)) * 0.5

                D_loss = D_H_loss + D_Z_loss

            opt_disc.zero_grad()
            d_scaler.scale(D_loss).backward()
            d_scaler.step(opt_disc)
            d_scaler.update()

            # === Generator step (adversarial + cycle consistency) ===
            with torch.amp.autocast('cuda', enabled=use_amp):
                D_H_fake = disc_H(fake_horse * masks, real_h)
                D_Z_fake = disc_Z(fake_zebra * masks, real_z)
                loss_G_H = _adv_bce(bce_loss_fn, D_H_fake, True)
                loss_G_Z = _adv_bce(bce_loss_fn, D_Z_fake, True)

                cycle_zebra = gen_Z(fake_horse)
                cycle_horse = gen_H(fake_zebra)
                cycle_zebra_loss = james_stein_reduce(l1_loss_fn(real_z, cycle_zebra * masks).view(-1))
                cycle_horse_loss = james_stein_reduce(l1_loss_fn(real_h, cycle_horse * masks).view(-1))

                G_loss = loss_G_H + loss_G_Z + LAMBDA_CYCLE * (cycle_zebra_loss + cycle_horse_loss)

            opt_gen.zero_grad()
            g_scaler.scale(G_loss).backward()
            g_scaler.step(opt_gen)
            g_scaler.update()

        # === Validation ===
        val_H_reals, val_H_fakes, val_loss_G_H_total, val_loss_G_Z_total = _validate(
            disc_H, disc_Z, gen_Z, gen_H, val_loader, bce_loss_fn, sigma=val_sigma
        )

        if epoch % 20 == 0:
            sys.stdout.write(
                f"epoch: {epoch} "
                f"Val H_real: {val_H_reals:.4f} Val H_fake: {val_H_fakes:.4f} "
                f"Val Loss G_H: {val_loss_G_H_total:.4f} Val Loss G_Z: {val_loss_G_Z_total:.4f}\n"
            )

        if val_loss_G_Z_total < best_val_loss:
            best_val_loss = val_loss_G_Z_total
            # Snapshot the weights so the early-stopping choice is what gets saved.
            best_gen_state = copy.deepcopy(gen_Z.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= pasciencia:
            print("Early stopping ativado!")
            break

    # Save/return the best-validation generator, not the last (over-trained) one.
    if best_gen_state is not None:
        gen_Z.load_state_dict(best_gen_state)

    save_dir = f"./models_saved/cyclegan/{channel}/{taxa}/fold{fold}"
    os.makedirs(save_dir, exist_ok=True)

    model_path = os.path.join(save_dir, "generator.pth")
    torch.save(gen_Z.state_dict(), model_path)
    print(f"Save model: {model_path}")

    return gen_Z


def test(gen_Z, test_loader, taxa, fold, chanells):
    """Evaluate ``gen_Z`` (input -> label) on ``test_loader`` and write per-sample metrics to CSV.

    For every sample:
      1. Map the generated image and the ground-truth label image back to the
         label's original value range. This inverts
         ``LoaderDataset.custom_normalize`` using the per-sample min/max the
         loader yields.
      2. Sum both images over channels and flatten them.
      3. Multiply both by the mask.
      4. Score them with MAE, ASMAPE, MAPE (%) and RMSE. The ``scale`` column
         is the unmasked label's value range.

    Rows are indexed by ``test_loader.dataset.horse_images``. Samples are
    matched to names in iteration order.
    NOTE(review): this assumes the loader does not shuffle.

    Output: ``./resultados/resultados_ciclegan/result_{chanells}c_{taxa}_{fold}.csv``.
    """
    # Use the generator's own device. A hard-coded device string ("cuda:1") can
    # differ from where the model was moved and break the forward pass.
    device = next(gen_Z.parameters()).device
    gen_Z.eval()

    names = list(test_loader.dataset.horse_images)
    df = pd.DataFrame([], columns=['mae', 'asmape', 'mape', 'rmse', 'scale'], index=names)

    sample_idx = 0
    with torch.no_grad():
        for zebra, horse, min_vals, max_vals, masks in test_loader:
            fake_zebra = gen_Z(horse.to(device)).float().cpu()
            zebra = zebra.float().cpu()
            batch = zebra.shape[0]

            # Invert custom_normalize per sample: x = (x_norm + 1) / 2 * scale + min,
            # using the same clamped scale that was used to normalize.
            min_vals = torch.as_tensor(min_vals, dtype=torch.float32).reshape(batch, 1, 1, 1)
            max_vals = torch.as_tensor(max_vals, dtype=torch.float32).reshape(batch, 1, 1, 1)
            scale = torch.clamp(max_vals - min_vals, min=1e-5)
            zebra = (zebra + 1) / 2 * scale + min_vals
            fake_zebra = (fake_zebra + 1) / 2 * scale + min_vals

            # Sum over channels; keep one flat vector per sample.
            zebra = zebra.sum(dim=1).reshape(batch, -1).numpy()
            fake_zebra = fake_zebra.sum(dim=1).reshape(batch, -1).numpy()
            masks = torch.as_tensor(masks, dtype=torch.float32).reshape(batch, -1).numpy()

            for z, f, m in zip(zebra, fake_zebra, masks):
                name = names[sample_idx % len(names)]
                sample_idx += 1

                z_masked = z * m
                f_masked = f * m
                mae_value = round(mae(z_masked, f_masked), 3)
                mape_value = round(mape(z_masked, f_masked) * 100, 3)
                rmse_value = round(np.sqrt(mse(z_masked, f_masked)), 3)
                smape_value = round(asmape(z_masked, f_masked, mask=m), 3)
                df.loc[name] = [mae_value, smape_value, mape_value, rmse_value, np.max(z) - np.min(z)]

    directory = "./resultados/resultados_ciclegan"
    os.makedirs(directory, exist_ok=True)
    df.to_csv(os.path.join(directory, f'result_{str(chanells)}c_{taxa}_{fold}.csv'))



DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda:0"
print(f"using device: {DEVICE}")

import os

# Every split is read from the same root; main() selects the per-rate/per-fold subfolders.
TRAIN_DIR = VAL_DIR = INDEX_TRAIN = INDEX_VAL = INDEX_TEST = os.path.abspath("../dataset_final")

# --- Training hyperparameters ---
BATCH_SIZE = 32
LEARNING_RATE = 1e-6
LAMBDA_IDENTITY = 0.0  # weight for an identity loss; not referenced by train_fn
LAMBDA_CYCLE = 10      # weight of the L1 cycle-consistency terms
NUM_WORKERS = 0
NUM_EPOCHS = 50000     # upper bound; early stopping usually ends training sooner
LOAD_MODEL = False
SAVE_MODEL = True


def main(in_channels):
    """Train one CycleGAN per (rate, fold) split with ``in_channels`` image channels.

    For each rate directory ('10', '20', '30', '40') and each fold ('1'..'5'):
      * build fresh instance-norm generators and discriminators with Adam
        optimizers;
      * create train/val loaders from ``TRAIN_DIR`` / ``VAL_DIR`` and the
        ``INDEX_*`` mask folders;
      * call ``train_fn``, which saves the trained ``gen_Z``.
    """
    # Element-wise losses (reduction='none'); train_fn reduces them with james_stein_reduce.
    # They are stateless, so one instance serves every split.
    l1_loss_fn = nn.L1Loss(reduction='none')
    bce_loss_fn = nn.BCEWithLogitsLoss(reduction='none')

    # GradScaler's first positional parameter is the device type, so
    # GradScaler(True) passed True as the device. Pass `enabled` by keyword and
    # only enable scaling when training on CUDA.
    use_amp = torch.device(DEVICE).type == 'cuda'

    for taxa in ['10', '20', '30', '40']:
        for fold in ['1', '2', '3', '4', '5']:
            disc_H = PatchDiscriminator(input_channels=in_channels, norm_type='instancenorm').to(DEVICE)
            disc_Z = PatchDiscriminator(input_channels=in_channels, norm_type='instancenorm').to(DEVICE)
            gen_Z = UNetGenerator(input_channels=in_channels, output_channels=in_channels, norm_type='instancenorm').to(DEVICE)
            gen_H = UNetGenerator(input_channels=in_channels, output_channels=in_channels, norm_type='instancenorm').to(DEVICE)

            opt_disc = optim.Adam(
                list(disc_H.parameters()) + list(disc_Z.parameters()),
                lr=LEARNING_RATE,
                betas=(0.5, 0.999),
            )

            opt_gen = optim.Adam(
                list(gen_Z.parameters()) + list(gen_H.parameters()),
                lr=LEARNING_RATE,
                betas=(0.5, 0.999),
            )

            dataset = LoaderDataset(
                root_zebra=os.path.join(TRAIN_DIR, "label", str(taxa), "folds", f"fold{fold}", "train"),
                root_horse=os.path.join(TRAIN_DIR, "input", str(taxa), "folds", f"fold{fold}", "train"),
                root_masks=os.path.join(INDEX_TRAIN, "input", str(taxa), "folds", f"fold{fold}", "index_train"),
                chanels=in_channels
            )

            val_dataset = LoaderDataset(
                root_zebra=os.path.join(VAL_DIR, "label", str(taxa), "folds", f"fold{fold}", "val"),
                root_horse=os.path.join(VAL_DIR, "input", str(taxa), "folds", f"fold{fold}", "val"),
                root_masks=os.path.join(INDEX_VAL, "input", str(taxa), "folds", f"fold{fold}", "index_val"),
                chanels=in_channels
            )

            val_loader = DataLoader(
                val_dataset,
                batch_size=BATCH_SIZE,
                shuffle=False,
                pin_memory=False,
                num_workers=NUM_WORKERS,
            )

            loader = DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                num_workers=NUM_WORKERS,
                pin_memory=False,
            )

            g_scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
            d_scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

            train_fn(
                disc_H,
                disc_Z,
                gen_Z,
                gen_H,
                loader,
                val_loader,
                opt_disc,
                opt_gen,
                l1_loss_fn,
                bce_loss_fn,
                d_scaler,
                g_scaler,
                in_channels,
                taxa,
                fold,
                pasciencia=200
            )


if __name__ == '__main__':
    # Run the full (rate x fold) training sweep once per channel configuration.
    for n_channels in (1, 2, 3):
        print(f'channels:{n_channels}')
        main(in_channels=n_channels)


# Captured notebook output (truncated):
# using device: cuda:0
# channels:1
# epoch: 0 Val H_real: -0.0644 Val H_fake: -0.0722 Val Loss G_H: 0.7381 Val Loss G_Z: 0.7807
# epoch: 20 Val H_real: -0.0154 Val H_fake: -0.0169 Val Loss G_H: 0.7081 Val Loss G_Z: 0.7103
# epoch: 40 Val H_real: -0.0320 Val H_fake: -0.0326 Val Loss G_H: 0.7154 Val Loss G_Z: 0.7143
# epoch: 60 Val H_real: -0.0482 Val H_fake: -0.0440 Val Loss G_H: 0.7210 Val Loss G_Z: 0.7189
# epoch: 80 Val H_real: -0.0670 Val H_fake: -0.0585 Val Loss G_H: 0.7284 Val Loss G_Z: 0.7294
# epoch: 100 Val H_real: -0.0780 Val H_fake: -0.0698 Val Loss G_H: 0.7343 Val Loss G_Z: 0.7383
# epoch: 120 Val H_real: -0.0931 Val H_fake: -0.0842 Val Loss G_H: 0.7418 Val Loss G_Z: 0.7454
# epoch: 140 Val H_real: -0.1080 Val H_fake: -0.0992 Val Loss G_H: 0.7498 Val Loss G_Z: 0.7550
# epoch: 160 Val H_real: -0.1231 Val H_fake: -0.1174 Val Loss G_H: 0.7595 Val Loss G_Z: 0.7628
# epoch: 180 Val H_real: -0.1476 Val H_fake: -0.1438 Val Loss G_H: 0.7738 Val Loss G_Z: 0.7754
# epoch: 200 Val H_real: -0.1677 Val H_fake: -