<a href="https://colab.research.google.com/github/Wemerson-ferr/TCC_SAGAN_IMPLEMENTATION/blob/main/SAGAN_MULT_TEST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install kaggle --quiet
!pip install torch-fidelity --quiet
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torchvision
import torchvision.utils as vutils
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import json
import csv
import torch_fidelity
from torch_fidelity.metrics import calculate_metrics
from torch.nn.utils import spectral_norm
from torch.utils.data import DataLoader
from google.colab import drive

In [None]:
# Mount Google Drive so that run artifacts can be read from / written to it.
drive.mount('/content/drive')

# Base folder on Drive for this project; it is created if it does not exist yet.
# run_experiment() creates one subfolder per experiment underneath it.
BASE_PROJECT_PATH = '/content/drive/MyDrive/Meu_TCC_SAGAN_V2'
os.makedirs(BASE_PROJECT_PATH, exist_ok=True)

In [None]:
# Helper that generates and saves images; kept outside the main experiment function.
def save_fake_images_for_evaluation(netG, device, save_path, num_images=5000, nz=100, batch_size=64):
    """Generate ``num_images`` samples with ``netG`` and save them as PNG files.

    Files are written to ``save_path`` (created if missing) and named
    ``fake_00000.png``, ``fake_00001.png``, ... Each image is saved with
    ``normalize=True``.

    Args:
        netG: generator, called with a noise tensor of shape
            ``(batch_size, nz, 1, 1)``.
        device: device on which the noise tensor is created.
        save_path: output directory.
        num_images: number of images to write; nothing is generated when
            it is zero or negative.
        nz: size of the latent vector fed to ``netG``. Defaults to 100, the
            value that used to be hard-coded here.
        batch_size: number of samples generated per forward pass.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print(f"Gerando {num_images} imagens para avaliação em '{save_path}'...")
    os.makedirs(save_path, exist_ok=True)

    for start in range(0, num_images, batch_size):
        with torch.no_grad():
            # A full batch is always sampled, even for the last (partial) chunk,
            # so the generator sees the same batch size on every call; the
            # surplus images are simply dropped below.
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_images = netG(noise).detach().cpu()

        remaining = num_images - start
        for offset, image in enumerate(fake_images[:remaining]):
            file_name = f'fake_{start + offset:05d}.png'
            vutils.save_image(image, os.path.join(save_path, file_name), normalize=True)

    print("Geração concluída.")

In [None]:
def plot_latent_space_interpolation(netG, device, save_dir, hyperparameters, num_steps=10):
    """Save a strip of images generated along a line between two random latents.

    Two latent vectors of size ``hyperparameters["latent_dim_nz"]`` are drawn,
    ``num_steps`` linearly interpolated points between them are fed to
    ``netG``, and the resulting image grid is written to
    ``<save_dir>/interpolacao_latente.png``.

    Args:
        netG: generator, called with a ``(num_steps, nz, 1, 1)`` tensor.
        device: device on which the latent vectors are created.
        save_dir: existing directory that receives the PNG file.
        hyperparameters: dict; only the "latent_dim_nz" key is read.
        num_steps: number of interpolation points (one image per point).
    """
    latent_dim = hyperparameters["latent_dim_nz"]

    # The two endpoints of the interpolation.
    start_z = torch.randn(1, latent_dim, 1, 1, device=device)
    end_z = torch.randn(1, latent_dim, 1, 1, device=device)

    # Linear blend from start_z (weight 0) to end_z (weight 1).
    blended = []
    for weight in torch.linspace(0, 1, num_steps):
        blended.append((1 - weight) * start_z + weight * end_z)
    latent_path = torch.cat(blended, dim=0)

    with torch.no_grad():
        generated = netG(latent_path).detach().cpu()

    # One row containing every interpolation step.
    image_grid = torchvision.utils.make_grid(generated, nrow=num_steps, padding=2, normalize=True)
    output_file = os.path.join(save_dir, 'interpolacao_latente.png')

    plt.figure(figsize=(15, 3))
    plt.imshow(image_grid.permute(1, 2, 0))
    plt.title("Interpolação no Espaço Latente")
    plt.axis("off")
    plt.savefig(output_file)
    plt.close()

In [None]:
# Kaggle API setup and dataset download.
# The whole step is skipped when the extracted dataset folder already exists.
# Lines starting with "!" are IPython/Colab shell escapes, so this cell only
# works inside a notebook environment.
KAGGLE_JSON_PATH = "/content/drive/MyDrive/Colab_Secrets/kaggle.json"
if not os.path.exists("/content/data/celebahq"):
    print("Configurando API do Kaggle e baixando o dataset...")
    # Copy the API credentials from Drive into ~/.kaggle and restrict their permissions.
    !mkdir -p ~/.kaggle
    !cp "{KAGGLE_JSON_PATH}" ~/.kaggle/
    !chmod 600 ~/.kaggle/kaggle.json
    # Download the archive into /content/data and extract it into /content/data/celebahq.
    !mkdir -p /content/data
    !kaggle datasets download -d lamsimon/celebahq -p /content/data
    !unzip -q /content/data/celebahq.zip -d /content/data/celebahq
    print("Dataset pronto!")
else:
    print("Dataset já existe. Pulando o download.")

In [None]:
def run_experiment(hyperparameters, experiment_name):
    """
    Run one complete experiment: setup, training, evaluation and logging.

    All artifacts are written under ``BASE_PROJECT_PATH/<experiment_name>``:
    the hyperparameters (JSON), a text dump of both model architectures,
    sample image grids and generator checkpoints (every 10 epochs), a CSV with
    the per-epoch losses, a CSV with FID / Inception Score evaluations, and a
    latent-space interpolation plot at the end of training.

    Args:
        hyperparameters: dict from which the following keys are read:
            "image_size", "batch_size", "latent_dim_nz",
            "gen_feature_map_size", "disc_feature_map_size", "num_channels",
            "learning_rate_gen", "learning_rate_disc", "optimizer_betas",
            "total_epochs" and "evaluation_interval". The whole dict is also
            dumped with ``json.dump``, so it must be JSON-serializable.
        experiment_name: name of the subfolder created for this run.

    Returns:
        None.
    """
    # --- Folder structure for the current experiment ---
    PROJECT_PATH = os.path.join(BASE_PROJECT_PATH, experiment_name)
    print(f"Pasta da execução definida como: {PROJECT_PATH}")

    IMAGES_PATH = os.path.join(PROJECT_PATH, 'imagens_geradas')
    CHECKPOINTS_PATH = os.path.join(PROJECT_PATH, 'checkpoints')
    LOGS_PATH = os.path.join(PROJECT_PATH, 'logs')
    # Base folder for the images generated for metric evaluation.
    EVAL_IMAGES_BASE_PATH = os.path.join(PROJECT_PATH, 'imagens_para_avaliacao')

    os.makedirs(IMAGES_PATH, exist_ok=True)
    os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
    os.makedirs(LOGS_PATH, exist_ok=True)
    os.makedirs(EVAL_IMAGES_BASE_PATH, exist_ok=True)

    # Persist the hyperparameters next to the results of the run.
    with open(os.path.join(PROJECT_PATH, 'hyperparameters.json'), 'w') as f:
        json.dump(hyperparameters, f, indent=4)

    # Device selection and input transformations (images are scaled to [-1, 1],
    # matching the Tanh output of the generator).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize(hyperparameters["image_size"]),
        transforms.CenterCrop(hyperparameters["image_size"]),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load the dataset; downloading/extracting it is done once outside this function.
    dataset = torchvision.datasets.ImageFolder(root="/content/data/celebahq", transform=transform)
    dataloader = DataLoader(dataset, batch_size=hyperparameters["batch_size"], shuffle=True, num_workers=4)

    # Model definitions: SelfAttention, Generator, Discriminator and weights_init.
    class SelfAttention(nn.Module):
        """Self-attention over all spatial positions with a learnable residual scale."""

        def __init__(self, in_channels):
            super().__init__()
            self.query = spectral_norm(nn.Conv2d(in_channels, in_channels // 8, 1))
            self.key = spectral_norm(nn.Conv2d(in_channels, in_channels // 8, 1))
            self.value = spectral_norm(nn.Conv2d(in_channels, in_channels, 1))
            # Learnable scale, starts at 0 so the layer is initially the identity.
            self.gamma = nn.Parameter(torch.tensor(0.0))

        def forward(self, x):
            batch_size, C, H, W = x.size()
            # Flatten the spatial dimensions: query is (B, H*W, C//8), key is (B, C//8, H*W).
            query = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
            key = self.key(x).view(batch_size, -1, H * W)
            value = self.value(x).view(batch_size, -1, H * W)
            attention_matrix = torch.bmm(query, key)
            attention_map = torch.softmax(attention_matrix, dim=-1)
            out = torch.bmm(value, attention_map.permute(0, 2, 1))
            out = out.view(batch_size, C, H, W)
            return self.gamma * out + x

    nz = hyperparameters["latent_dim_nz"]
    ngf = hyperparameters["gen_feature_map_size"]
    ndf = hyperparameters["disc_feature_map_size"]
    nc = hyperparameters["num_channels"]

    class Generator(nn.Module):
        """Maps (B, nz, 1, 1) noise to (B, nc, 64, 64) images in [-1, 1]."""

        def __init__(self):
            super(Generator, self).__init__()
            # Spatial size: 1 -> 4 -> 8 -> 16 -> 32 -> 64.
            self.main = nn.Sequential(
                nn.ConvTranspose2d(nz, ngf*8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf*8), nn.Mish(),
                nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf*4), nn.Mish(),
                nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf*2), nn.Mish(),
                nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.Mish(),
                SelfAttention(ngf),
                nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()
            )

        def forward(self, x):
            return self.main(x)

    class Discriminator(nn.Module):
        """Scores images; the final 4x4 conv yields one value per image only for 64x64 inputs."""

        def __init__(self):
            super(Discriminator, self).__init__()
            self.main = nn.Sequential(
                spectral_norm(nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)), nn.LeakyReLU(),
                spectral_norm(nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False)), nn.BatchNorm2d(ndf*2), nn.LeakyReLU(),
                spectral_norm(nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False)), nn.BatchNorm2d(ndf*4), nn.LeakyReLU(),
                SelfAttention(ndf*4),
                spectral_norm(nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False)), nn.BatchNorm2d(ndf*8), nn.LeakyReLU(),
                spectral_norm(nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False)),
            )

        def forward(self, x):
            return self.main(x)

    def weights_init(m):
        """Normal(0, 0.02) init for conv weights, Normal(1, 0.02) / 0 for BatchNorm."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # With torch.nn.utils.spectral_norm the trainable tensor lives in
            # `weight_orig`; `weight` is recomputed from it before every
            # forward pass, so initializing `weight` alone would be discarded.
            weight = getattr(m, 'weight_orig', m.weight)
            nn.init.normal_(weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # Instantiate models and optimizers.
    netG = Generator().to(device)
    netD = Discriminator().to(device)
    netG.apply(weights_init)
    netD.apply(weights_init)
    optimizerD = optim.AdamW(netD.parameters(), lr=hyperparameters["learning_rate_disc"], betas=hyperparameters["optimizer_betas"], weight_decay=0.01)
    optimizerG = optim.AdamW(netG.parameters(), lr=hyperparameters["learning_rate_gen"], betas=hyperparameters["optimizer_betas"], weight_decay=0.01)

    # Save a textual description of both architectures.
    with open(os.path.join(PROJECT_PATH, 'model_architecture.txt'), 'w') as f:
        f.write("================== GERADOR ==================\n")
        f.write(str(netG))
        f.write("\n\n================ DISCRIMINADOR ================\n")
        f.write(str(netD))

    # Log files.
    log_file_loss = os.path.join(LOGS_PATH, 'historico_treinamento.csv')
    log_file_metrics = os.path.join(LOGS_PATH, 'evaluation_metrics.csv')

    # Create (truncate) the loss log and write its header.
    with open(log_file_loss, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Epoch', 'Loss_Discriminator', 'Loss_Generator'])

    # Create (truncate) the metrics log and write its header.
    with open(log_file_metrics, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch_evaluated', 'frechet_inception_distance', 'inception_score_mean', 'inception_score_std'])

    # Training loop.
    epochs = hyperparameters["total_epochs"]
    # Fixed noise so the sample grids of different epochs are comparable.
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    REAL_IMAGES_PATH = "/content/data/celebahq"  # Path of the real images never changes.

    print("Iniciando o treinamento...")
    for epoch in range(epochs):
        for i, (real, _) in enumerate(dataloader):
            # --- Discriminator step (hinge loss) ---
            real = real.to(device)
            b_size = real.size(0)
            netD.zero_grad()
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise)
            real_output = netD(real).view(-1)
            # detach(): the discriminator step must not backpropagate into the generator.
            fake_output = netD(fake.detach()).view(-1)
            errD_real = torch.mean(torch.relu(1.0 - real_output))
            errD_fake = torch.mean(torch.relu(1.0 + fake_output))
            errD = errD_real + errD_fake
            errD.backward()
            optimizerD.step()

            # --- Generator steps: two updates per discriminator update ---
            for _ in range(2):
                netG.zero_grad()
                noise = torch.randn(b_size, nz, 1, 1, device=device)
                fake = netG(noise)
                output = netD(fake).view(-1)
                errG = -torch.mean(output)
                errG.backward()
                optimizerG.step()

        # Per-epoch loss log (values of the last batch of the epoch).
        print(f"[{epoch+1}/{epochs}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f}")
        with open(log_file_loss, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([epoch+1, errD.item(), errG.item()])

        # Save a sample grid and a generator checkpoint every 10 epochs.
        if (epoch+1) % 10 == 0:
            with torch.no_grad():
                fake_imgs = netG(fixed_noise).detach().cpu()
                grid = torchvision.utils.make_grid(fake_imgs, padding=2, normalize=True)
                plt.figure(figsize=(8, 8))
                plt.imshow(grid.permute(1, 2, 0))
                plt.title(f"Imagens Geradas na Época {epoch+1}")
                plt.axis("off")
                plt.savefig(os.path.join(IMAGES_PATH, f'epoca_{epoch+1:04d}.png'))
                plt.close()

            checkpoint_path = os.path.join(CHECKPOINTS_PATH, f'checkpoint_epoca_{epoch+1}.pth')
            # NOTE: the stored 'epoch' is the 0-based index, while the file name is 1-based.
            torch.save({'epoch': epoch, 'netG_state_dict': netG.state_dict()}, checkpoint_path)
            print(f"Checkpoint da época {epoch+1} salvo.")

        # Periodic metric evaluation: on the first epoch, on the last epoch,
        # and every `evaluation_interval` epochs.
        if (epoch + 1) % hyperparameters["evaluation_interval"] == 0 or (epoch + 1) == epochs or epoch == 0:
            print(f"\n--- Realizando avaliação na época {epoch+1} ---")

            # Dedicated folder for the images of this evaluation.
            current_eval_path = os.path.join(EVAL_IMAGES_BASE_PATH, f"epoca_{epoch+1}")
            save_fake_images_for_evaluation(netG, device, current_eval_path, num_images=2000)

            # Compute the metrics on the same kind of device used for training,
            # instead of unconditionally requesting CUDA.
            metrics = calculate_metrics(
                input1=REAL_IMAGES_PATH,
                input2=current_eval_path,
                cuda=(device.type == "cuda"), isc=True, fid=True, verbose=False,
                samples_find_deep=True
            )

            # Write the values by key, in the same order as the CSV header; the
            # iteration order of the returned dict is not tied to that header.
            with open(log_file_metrics, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    epoch + 1,
                    metrics.get('frechet_inception_distance'),
                    metrics.get('inception_score_mean'),
                    metrics.get('inception_score_std'),
                ])

            print(f"--- Avaliação da Época {epoch+1} concluída. FID: {metrics.get('frechet_inception_distance'):.2f}, IS: {metrics.get('inception_score_mean'):.2f} ---\n")

    print("Treinamento concluído!")
    # Latent-space interpolation plot at the end of training.
    plot_latent_space_interpolation(netG, device, IMAGES_PATH, hyperparameters)


In [None]:
# Hyperparameter sets. Each dict is passed unchanged to run_experiment() and is
# also saved as hyperparameters.json inside the experiment folder.
experiment_1_feature_map_128 = {
    "total_epochs": 150,
    "evaluation_interval": 25,
    "image_size": 64,
    "batch_size": 128,
    "latent_dim_nz": 100,
    "gen_feature_map_size": 128,
    "disc_feature_map_size": 128,
    "num_channels": 3,
    "learning_rate_gen": 0.00002,
    "learning_rate_disc": 0.00001,
    "optimizer_betas": (0.5, 0.999),
    "dataset": "CelebA-HQ",
    "notes": "Execução com feature_map=128 para testar impacto."
}

experiment_2_feature_map_256 = {
    "total_epochs": 150,
    "evaluation_interval": 25,
    "image_size": 64,
    "batch_size": 128,
    "latent_dim_nz": 100,
    "gen_feature_map_size": 256,
    "disc_feature_map_size": 256,
    "num_channels": 3,
    "learning_rate_gen": 0.00002,
    "learning_rate_disc": 0.00001,
    "optimizer_betas": (0.5, 0.999),
    "dataset": "CelebA-HQ",
    "notes": "Execução com feature_map=256 e nn.Mish. Avaliação a cada 25 epocas."
}

# This configuration used to be assigned to `experiment_2_feature_map_256`
# again, silently overwriting the 256 configuration above; it now has its own
# name and a note that matches its feature-map size.
experiment_3_feature_map_512 = {
    "total_epochs": 150,
    "evaluation_interval": 25,
    "image_size": 64,
    "batch_size": 128,
    "latent_dim_nz": 100,
    "gen_feature_map_size": 512,
    "disc_feature_map_size": 512,
    "num_channels": 3,
    "learning_rate_gen": 0.00002,
    "learning_rate_disc": 0.00001,
    "optimizer_betas": (0.5, 0.999),
    "dataset": "CelebA-HQ",
    "notes": "Execução com feature_map=512 e nn.Mish. Avaliação a cada 25 epocas."
}


# Experiments that will actually be executed, keyed by output folder name.
# Each key is paired with the configuration of the matching feature-map size.
# The 512 configuration is defined above but not scheduled; add it here to run it.
experiments_to_run = {
    "Experimento_Mish_FM256": experiment_2_feature_map_256,
    "Experimento_Mish_FM128": experiment_1_feature_map_128,
}

In [None]:
# Main driver: runs every configured experiment in turn, printing a banner
# before and after each one.
separator = "========================================================"
for name, params in experiments_to_run.items():
    print("\n" + separator)
    print(f"INICIANDO EXPERIMENTO: {name}")
    print(separator)

    run_experiment(hyperparameters=params, experiment_name=name)

    print("\n" + separator)
    print(f"EXPERIMENTO {name} CONCLUÍDO")
    print(separator)
