# Imports

In [1]:
import torch
import torch.nn as nn 
import torch.optim as optim 

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Model architecture

In [145]:
class Discriminator(nn.Module):
    """Fully connected real/fake scorer for flattened images.

    The network is Linear(img_dimension -> 128), LeakyReLU(0.1),
    Linear(128 -> 1) and a final Sigmoid, so each input row produces a
    single sigmoid-activated score.

    Args:
        img_dimension: Number of features in each flattened input image.
    """

    def __init__(self, img_dimension):
        super().__init__()
        # Layers are created in execution order and stored in one Sequential
        # named `disc`, so parameters are addressed as `disc.<index>.*`.
        layers = [
            nn.Linear(in_features=img_dimension, out_features=128),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(in_features=128, out_features=1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score produced by ``self.disc`` for each row of ``x``."""
        scores = self.disc(x)
        return scores

In [146]:
class Generator(nn.Module):
    """Fully connected network mapping noise vectors to flattened images.

    The network is Linear(noise_dimension -> 256), LeakyReLU(0.1),
    Linear(256 -> 128), LeakyReLU(0.1), Linear(128 -> img_dimension) and a
    final Tanh, so every output value lies in [-1, 1].

    Args:
        noise_dimension: Number of features in each input noise vector.
        img_dimension: Number of features in each generated (flattened) image.
    """

    def __init__(self, noise_dimension, img_dimension):
        super().__init__()
        # Layers are created in execution order and stored in one Sequential
        # named `gen`, so parameters are addressed as `gen.<index>.*`.
        layers = [
            nn.Linear(in_features=noise_dimension, out_features=256),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(in_features=256, out_features=128),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(in_features=128, out_features=img_dimension),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return the output of ``self.gen`` for each row of ``x``."""
        generated = self.gen(x)
        return generated


In [147]:
    def __init__(self, noise_dimension, img_dimension):
        """Build a two-layer fully connected network stored in ``self.gen``.

        Args:
            noise_dimension: Number of features in each input noise vector.
            img_dimension: Number of features in each generated output row.
        """
        # NOTE(review): no `class` header is visible above this method in
        # this cell -- confirm which class it is meant to belong to.
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(noise_dimension, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, img_dimension),
            # We want the generated image values to lie between -1 and 1,
            # because the MNIST inputs are in the range -1 to 1.
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)
    
    def forward(self, x):
        """Pass ``x`` through ``self.gen`` and return the result."""
        generated = self.gen(x)
        return generated


# Hyperparams

In [148]:
# Run on the GPU when one is available; models and tensors below are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's global RNG so weight initialisation, noise sampling and data
# shuffling start from the same state after every "Restart & Run All"
# (GPU kernels may still introduce small nondeterminism).
seed = 42
torch.manual_seed(seed)

lr = 3e-4  # learning rate used by both Adam optimizers
noise_dim = 64  # number of features in each generator input noise vector
image_dim = 28 * 28 * 1  # flattened 28x28 single-channel image
batch_size = 32  # samples per DataLoader batch
epochs = 50  # number of passes over the data loader
step = 0  # TensorBoard global_step counter, incremented by the training loop

# Data

In [None]:
# Preprocessing applied to every image: convert it to a tensor, then normalise
# with mean 0.5 and std 0.5. The pipeline is built in this cell, with
# fully-qualified names, so the cell runs on a fresh kernel and keeps working
# even if the `transforms` alias is rebound elsewhere in the notebook.
mnist_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

# `transform` must be a callable applied to each image; passing the
# `torchvision.transforms` module itself fails as soon as a sample is loaded.
df = datasets.MNIST(root="dataset/", transform=mnist_transform, download=True)
loader = DataLoader(df, batch_size=batch_size, shuffle=True)

# Model load and train

In [149]:
# Instantiate both networks and move them to the selected device.
disc = Discriminator(image_dim).to(device)
gen = Generator(noise_dim, image_dim).to(device)

# One fixed batch of noise, reused every time samples are logged so the
# generated images can be compared from one logging step to the next.
fixed_noise = torch.randn((batch_size, noise_dim)).to(device)

# NOTE: this assignment rebinds the notebook-level name `transforms`, which
# the import cell bound to the `torchvision.transforms` module. The pipeline
# is therefore built from fully-qualified names: the previous
# `transforms.Compose(...)` form raised AttributeError when this cell was run
# a second time, because `transforms` was by then a Compose instance.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [150]:

# One Adam optimizer per network, both using the shared learning rate `lr`,
# so the discriminator and the generator can be stepped independently.
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)

In [151]:
# Binary cross-entropy between the discriminator's sigmoid output and the
# real (1) / fake (0) targets built in the training loop.
criterion = nn.BCELoss()

# Separate TensorBoard runs for generated and real image grids.
writer_fake = SummaryWriter(log_dir="runs/GAN_MNIST/fake")
writer_real = SummaryWriter(log_dir="runs/GAN_MNIST/real")

In [152]:
# Alternating GAN training: for every batch, update the discriminator on
# real and generated samples, then update the generator so that the
# discriminator scores its samples as real.
for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten every image to a vector of `image_dim` features.
        real = real.view(-1, image_dim).to(device)
        # Loop-local name on purpose: assigning to `batch_size` here would
        # overwrite the hyperparameter that other cells read.
        cur_batch_size = real.shape[0]

        ## Generate one noise vector per real sample
        noise = torch.randn(cur_batch_size, noise_dim).to(device)
        fake = gen(noise)

        ## Discriminator: real samples are labelled 1, generated samples 0
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        # detach() keeps this backward pass from reaching the generator.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ## Generator: rewarded when the discriminator scores fakes as real
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Report and log image grids once per epoch, on the first batch.
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                # Reshape the flat vectors back to 1x28x28 images for the grid.
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1
    

Epoch [0/50] Batch 0/1875                       Loss D: 0.6427, loss G: 0.6828
Epoch [1/50] Batch 0/1875                       Loss D: 0.3440, loss G: 1.3385
Epoch [2/50] Batch 0/1875                       Loss D: 0.4925, loss G: 1.3491
Epoch [3/50] Batch 0/1875                       Loss D: 0.5468, loss G: 1.0681
Epoch [4/50] Batch 0/1875                       Loss D: 0.5967, loss G: 1.1499
Epoch [5/50] Batch 0/1875                       Loss D: 0.6295, loss G: 1.0062
Epoch [6/50] Batch 0/1875                       Loss D: 0.6767, loss G: 0.9817
Epoch [7/50] Batch 0/1875                       Loss D: 1.1963, loss G: 0.6574
Epoch [8/50] Batch 0/1875                       Loss D: 0.8996, loss G: 0.7390
Epoch [9/50] Batch 0/1875                       Loss D: 0.6270, loss G: 1.1691
Epoch [10/50] Batch 0/1875                       Loss D: 0.5382, loss G: 1.0862
Epoch [11/50] Batch 0/1875                       Loss D: 0.6608, loss G: 1.0290
Epoch [12/50] Batch 0/1875                       L