<a href="https://colab.research.google.com/github/FGalvao77/GAN-com-Torch/blob/main/GAN_com_Torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **GAN with `Torch`**

---



In [2]:
# import libraries and functions
import torch
import torch.nn as nn 
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter   # for logging to TensorBoard

In [5]:
# define the discriminator architecture
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1), 
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.discriminator(x)
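
As a quick sanity check of the architecture above, the sketch below repeats the same `Discriminator` so that the snippet is self-contained and pushes a random batch through it. The batch size of 4 and the random input are arbitrary choices made only for illustration; the point is that each flattened 784-value image maps to a single score between 0 and 1.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Same layers as the Discriminator above, repeated for a standalone check."""

    def __init__(self, img_dim):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.discriminator(x)


img_dim = 28 * 28                      # flattened MNIST image
demo_batch = torch.randn(4, img_dim)   # 4 random "images" (arbitrary batch size)
scores = Discriminator(img_dim)(demo_batch)

print(scores.shape)  # one score per image
```

The final `Sigmoid` is what allows the scores to be read as probabilities and fed to a binary cross-entropy loss.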

In [6]:
# define the generator architecture
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(z_dim, 256), 
            nn.LeakyReLU(0.2),
            nn.Linear(256, img_dim),    
            nn.Tanh()   # outputs in [-1, 1], matching real images normalized to [-1, 1]
        )

    def forward(self, x):
        return self.generator(x)
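
The following sketch, again repeating the class so it runs on its own, illustrates what the generator produces: a noise vector of size `z_dim` goes in, and a flattened image with values bounded by `Tanh` comes out. The batch of 4 noise vectors is an assumption made for the demonstration only.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Same layers as the Generator above, repeated for a standalone check."""

    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.generator(x)


z_dim, img_dim = 64, 28 * 28
noise = torch.randn(4, z_dim)               # 4 noise vectors (arbitrary batch size)
fake = Generator(z_dim, img_dim)(noise)

print(fake.shape)                            # 4 flattened images
print(fake.min().item(), fake.max().item())  # both within [-1, 1]
images = fake.reshape(-1, 1, 28, 28)         # back to image layout for plotting
```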

In [12]:
# define the hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4
z_dim = 64  # 128, 156
image_dim = 28 * 28 * 1   # 784
batch_size = 32
num_epochs = 50

In [13]:
# instantiate the models
model_discriminator = Discriminator(image_dim).to(device)
model_generator = Generator(z_dim, image_dim).to(device)

# fixed noise, reused to visualize the generator's progress over training
fixed_noise = torch.randn(batch_size, z_dim, device=device)

In [14]:
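# define the data normalization: ToTensor scales pixels to [0, 1],
# then Normalize maps them to [-1, 1], the range of the generator's Tanh output
# (named `transform` so the `transforms` module is not overwritten)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

In [15]:
# load the dataset, stored in the "dataset" folder
dataset = datasets.MNIST(root='dataset/', transform=transform, download=True)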

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz
Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw
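
The normalization applied to the dataset can be checked by hand. The sketch below reproduces the per-pixel arithmetic of `Normalize` with mean 0.5 and standard deviation 0.5 in plain Python; the helper names `normalize` and `denormalize` are hypothetical and exist only for this illustration.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-pixel arithmetic applied by Normalize: (pixel - mean) / std."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, useful before displaying an image."""
    return value * std + mean


# ToTensor scales pixels to [0, 1]; these three values cover that range.
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```

Black pixels map to -1 and white pixels to 1, so real images live in the same range as the `Tanh` output of the generator.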



In [17]:
# create the data loader
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
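
As a small worked example of what `len(loader)` evaluates to here: the MNIST training split has 60,000 images, and `DataLoader` keeps a final partial batch by default (`drop_last=False`), so the number of batches per epoch is the ceiling of the division. The variable names below are chosen for illustration.

```python
import math

num_images = 60_000  # size of the MNIST training split
batch_size = 32      # same value as the hyperparameter above

# With drop_last=False (the default), a final partial batch is kept,
# so the batch count is the ceiling of the division.
batches_per_epoch = math.ceil(num_images / batch_size)
print(batches_per_epoch)  # 1875
```

This agrees with the `Batch 0/1875` entries printed in the training log.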

In [18]:
# create one optimizer per model
opt_discriminator = optim.Adam(model_discriminator.parameters(), lr=lr)
opt_generator = optim.Adam(model_generator.parameters(), lr=lr)

In [19]:
# instantiate the loss function (binary cross-entropy)
criterion = nn.BCELoss()
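
To see why binary cross-entropy fits the GAN objective, the sketch below writes the per-sample formula in plain Python. With a target of 1 the loss reduces to `-log(p)`, and with a target of 0 it reduces to `-log(1 - p)`. The function name `bce` and the value `0.8` are hypothetical; `nn.BCELoss` additionally averages over the batch by default.

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for a single prediction in (0, 1)."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


d_out = 0.8  # hypothetical discriminator output

loss_target_one = bce(d_out, 1.0)   # equals -log(0.8)
loss_target_zero = bce(d_out, 0.0)  # equals -log(1 - 0.8)
print(loss_target_one, loss_target_zero)
```

Minimizing the loss with target 1 therefore maximizes `log(D(x))`, and minimizing it with target 0 maximizes `log(1 - D(x))`.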

In [20]:
# define the folders where TensorBoard stores the images produced by the models
writer_fake = SummaryWriter('runs/GAN_MNIST/fake')
writer_real = SummaryWriter('runs/GAN_MNIST/real')

In [37]:
# training loop
step = 0

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)
        batch_size = real.shape[0]

        # train the Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = model_generator(noise)
        discriminator_real = model_discriminator(real).view(-1)
        lossD_real = criterion(discriminator_real, torch.ones_like(discriminator_real))
        # detach so the discriminator update does not backpropagate into the generator
        discriminator_fake = model_discriminator(fake.detach()).view(-1)
        lossD_fake = criterion(discriminator_fake, torch.zeros_like(discriminator_fake))
        lossD = (lossD_real + lossD_fake) / 2
        model_discriminator.zero_grad()
        lossD.backward()
        opt_discriminator.step()

        # train the Generator: min log(1 - D(G(z))) <----> max log(D(G(z)))
        # the second (maximization) form does not suffer from saturating gradients
        output = model_discriminator(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        model_generator.zero_grad()
        lossG.backward()
        opt_generator.step()

        if batch_idx == 0:
            print(
                f'Epoch [{epoch}/{num_epochs}] | Batch {batch_idx}/{len(loader)}\
                Loss D: {lossD:.4f}, Loss G: {lossG:.4f}'
            )

            with torch.no_grad():
                fake = model_generator(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)

                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    'MNIST Fake Images', img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    'MNIST Real Images', img_grid_real, global_step=step
                )
                 
                step += 1

Epoch [0/50] | Batch 0/1875                Loss D: 0.5948, Loss G: 0.9447
Epoch [1/50] | Batch 0/1875                Loss D: 0.5650, Loss G: 0.9836
Epoch [2/50] | Batch 0/1875                Loss D: 0.6284, Loss G: 0.8983
Epoch [3/50] | Batch 0/1875                Loss D: 0.6933, Loss G: 0.7486
Epoch [4/50] | Batch 0/1875                Loss D: 0.6357, Loss G: 0.7663
Epoch [5/50] | Batch 0/1875                Loss D: 0.6151, Loss G: 1.0822
Epoch [6/50] | Batch 0/1875                Loss D: 0.7011, Loss G: 0.9672
Epoch [7/50] | Batch 0/1875                Loss D: 0.6718, Loss G: 0.6944
Epoch [8/50] | Batch 0/1875                Loss D: 0.6880, Loss G: 0.8440
Epoch [9/50] | Batch 0/1875                Loss D: 0.6014, Loss G: 0.8457
Epoch [10/50] | Batch 0/1875                Loss D: 0.7164, Loss G: 0.8056
Epoch [11/50] | Batch 0/1875                Loss D: 0.6835, Loss G: 0.8391
Epoch [12/50] | Batch 0/1875                Loss D: 0.5281, Loss G: 0.9644
Epoch [13/50] | Batch 0/1875       
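
The comment in the training loop about saturating gradients can be made concrete with a short calculation. Assuming the discriminator output is `sigmoid(a)` for some pre-sigmoid value `a` (a name introduced here for illustration), the derivative of `log(1 - sigmoid(a))` with respect to `a` is `-sigmoid(a)`, while the derivative of `-log(sigmoid(a))` is `-(1 - sigmoid(a))`. The sketch below evaluates both.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the original minimization form."""
    return -sigmoid(a)


def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)), the form used in the training loop."""
    return -(1.0 - sigmoid(a))


# a = -6 corresponds to a discriminator that confidently rejects the fake.
for a in (-6.0, 0.0, 6.0):
    print(a, grad_saturating(a), grad_non_saturating(a))
```

When the discriminator confidently rejects a fake (very negative `a`), the first gradient is close to zero, so the generator learns almost nothing, whereas the second stays close to -1. This is the situation early in training, when generated samples are easy to tell apart from real ones.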

**Things to try:**
- 1. What happens if you use a larger network?
- 2. Better normalization with BatchNorm?
- 3. Different learning rate (is there a better one)?
- 4. Change architecture to a CNN?
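
As one possible starting point for the first two suggestions, the sketch below shows a hypothetical `BiggerGenerator` with an extra hidden layer and `BatchNorm1d` after each hidden linear layer. The class name and layer widths are assumptions for illustration, not tuned values, and whether this variant actually trains better is left to experiment.

```python
import torch
import torch.nn as nn


class BiggerGenerator(nn.Module):
    """Hypothetical variant: one extra hidden layer plus BatchNorm1d."""

    def __init__(self, z_dim, img_dim, hidden_dim=256):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 2, img_dim),
            nn.Tanh(),  # keep outputs in [-1, 1], like the original Generator
        )

    def forward(self, x):
        return self.generator(x)


model = BiggerGenerator(z_dim=64, img_dim=28 * 28)
# BatchNorm in training mode needs more than one sample per batch.
sample = model(torch.randn(8, 64))
print(sample.shape)
```

Because the constructor takes `z_dim` and `img_dim` like the original `Generator`, it could be swapped in where `model_generator` is instantiated without changing the training loop.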

In [44]:
# %load_ext tensorboard

In [45]:
# # import library for timestamps
# import datetime as dt
# import tensorflow as tf

# # set up the log directory
# !rm -rf ./logs/
# log_dir = 'logs/fit/' + dt.datetime.now().strftime('%Y%m%d-%H%M%S') 
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

In [46]:
# %tensorboard --logdir logs/fit